Skip to content

Commit

Permalink
feat: add SAFETY_DB_DIR env var to the scan command (#523)
Browse files Browse the repository at this point in the history
  • Loading branch information
yeisonvargasf authored May 1, 2024
1 parent 4ef66e0 commit a1692d5
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 15 deletions.
2 changes: 1 addition & 1 deletion safety/auth/cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def update_token(tokens, **kwargs):
try:
openid_config = client_session.get(url=OPENID_CONFIG_URL, timeout=REQUEST_TIMEOUT).json()
except Exception as e:
LOG.exception('Unable to load the openID config: %s', e)
LOG.debug('Unable to load the openID config: %s', e)
openid_config = {}

client_session.metadata["token_endpoint"] = openid_config.get("token_endpoint",
Expand Down
4 changes: 4 additions & 0 deletions safety/auth/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
import os
from typing import Any, Optional

from authlib.integrations.base_client import BaseOAuth
Expand Down Expand Up @@ -26,6 +27,9 @@ class Auth:
email_verified: bool = False

def is_valid(self) -> bool:
if os.getenv("SAFETY_DB_DIR"):
return True

if not self.client:
return False

Expand Down
11 changes: 10 additions & 1 deletion safety/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,20 @@

LOG = logging.getLogger(__name__)


def configure_logger(ctx, param, debug):
level = logging.CRITICAL

if debug:
level = logging.DEBUG

logging.basicConfig(format='%(asctime)s %(name)s => %(message)s', level=level)

@click.group(cls=SafetyCLILegacyGroup, help=CLI_MAIN_INTRODUCTION, epilog=DEFAULT_EPILOG)
@auth_options()
@proxy_options
@click.option('--disable-optional-telemetry', default=False, is_flag=True, show_default=True, help=CLI_DISABLE_OPTIONAL_TELEMETRY_DATA_HELP)
@click.option('--debug', default=False, help=CLI_DEBUG_HELP)
@click.option('--debug', default=False, help=CLI_DEBUG_HELP, callback=configure_logger)
@click.version_option(version=get_safety_version())
@click.pass_context
@inject_session
Expand Down
27 changes: 20 additions & 7 deletions safety/safety.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,21 @@ def post_results(session, safety_json, policy_file):
return {}


def fetch_database_file(path, db_name, ecosystem: Ecosystem = Ecosystem.PYTHON):
full_path = os.path.join(path, db_name)
if not os.path.exists(full_path):
def fetch_database_file(path: str, db_name: str, cached = 0,
ecosystem: Optional[Ecosystem] = None):
full_path = (Path(path) / (ecosystem.value if ecosystem else '') / db_name).expanduser().resolve()

if not full_path.exists():
raise DatabaseFileNotFoundError(db=path)

with open(full_path) as f:
return json.loads(f.read())
data = json.loads(f.read())

if cached:
LOG.info('Writing %s to cache because cached value was %s', db_name, cached)
write_to_cache(db_name, data)

return data


def is_valid_database(db) -> bool:
Expand All @@ -218,7 +227,8 @@ def is_valid_database(db) -> bool:


def fetch_database(session, full=False, db=False, cached=0, telemetry=True,
ecosystem: Ecosystem = Ecosystem.PYTHON, from_cache=True):
ecosystem: Optional[Ecosystem] = None, from_cache=True):

if session.is_using_auth_credentials():
mirrors = API_MIRRORS
elif db:
Expand All @@ -230,10 +240,13 @@ def fetch_database(session, full=False, db=False, cached=0, telemetry=True,
for mirror in mirrors:
# mirror can either be a local path or a URL
if is_a_remote_mirror(mirror):
if ecosystem is None:
ecosystem = Ecosystem.PYTHON
data = fetch_database_url(session, mirror, db_name=db_name, cached=cached,
telemetry=telemetry, ecosystem=ecosystem, from_cache=from_cache)
else:
data = fetch_database_file(mirror, db_name=db_name, ecosystem=ecosystem)
data = fetch_database_file(mirror, db_name=db_name, cached=cached,
ecosystem=ecosystem)
if data:
if is_valid_database(data):
return data
Expand Down Expand Up @@ -1000,7 +1013,7 @@ def get_licenses(*, session=None, db_mirror=False, cached=0, telemetry=True):
licenses = fetch_database_url(session, mirror, db_name=db_name, cached=cached,
telemetry=telemetry)
else:
licenses = fetch_database_file(mirror, db_name=db_name)
licenses = fetch_database_file(mirror, db_name=db_name, ecosystem=None)
if licenses:
return licenses
raise DatabaseFetchError()
Expand Down
13 changes: 9 additions & 4 deletions safety/scan/decorators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from functools import wraps
import logging
import os
from pathlib import Path
from random import randint
import sys
Expand Down Expand Up @@ -135,11 +136,15 @@ def inner(ctx, policy_file_path: Optional[Path], target: Path,
if ctx.obj.auth.client.get_authentication_type() == "api_key":
details = {"Account": f"API key used"}
else:
content = ctx.obj.auth.email
if ctx.obj.auth.name != ctx.obj.auth.email:
content = f"{ctx.obj.auth.name}, {ctx.obj.auth.email}"

details = {"Account": f"{content} {render_email_note(ctx.obj.auth)}"}
if ctx.obj.auth.client.get_authentication_type() == "token":
content = ctx.obj.auth.email
if ctx.obj.auth.name != ctx.obj.auth.email:
content = f"{ctx.obj.auth.name}, {ctx.obj.auth.email}"

details = {"Account": f"{content} {render_email_note(ctx.obj.auth)}"}
else:
details = {"Account": f"Offline - {os.getenv('SAFETY_DB_DIR')}"}

if ctx.obj.project.id:
details["Project"] = ctx.obj.project.id
Expand Down
10 changes: 8 additions & 2 deletions safety/scan/finder/handlers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
import os
from pathlib import Path
from types import MappingProxyType
from typing import Dict, List, Optional, Tuple
Expand Down Expand Up @@ -49,12 +50,17 @@ def __init__(self) -> None:

def download_required_assets(self, session):
from safety.safety import fetch_database

SAFETY_DB_DIR = os.getenv("SAFETY_DB_DIR")

db = False if SAFETY_DB_DIR is None else SAFETY_DB_DIR


fetch_database(session=session, full=False, db=False, cached=True,
fetch_database(session=session, full=False, db=db, cached=True,
telemetry=True, ecosystem=Ecosystem.PYTHON,
from_cache=False)

fetch_database(session=session, full=True, db=False, cached=True,
fetch_database(session=session, full=True, db=db, cached=True,
telemetry=True, ecosystem=Ecosystem.PYTHON,
from_cache=False)

Expand Down
4 changes: 4 additions & 0 deletions safety/scan/validators.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

import os
from pathlib import Path
from typing import Optional, Tuple
import typer
Expand Down Expand Up @@ -42,6 +43,9 @@ def fail_if_not_allowed_stage(ctx: typer.Context):
stage = ctx.obj.auth.stage
auth_type: AuthenticationType = ctx.obj.auth.client.get_authentication_type()

if os.getenv("SAFETY_DB_DIR"):
return

if not auth_type.is_allowed_in(stage):
raise typer.BadParameter(f"'{auth_type.value}' auth type isn't allowed with " \
f"the '{stage}' stage.")
Expand Down
29 changes: 29 additions & 0 deletions tests/scan/test_file_handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import os
import pytest
from unittest.mock import Mock, patch
from safety.scan.finder.handlers import PythonFileHandler

@patch('safety.safety.fetch_database')
def test_download_required_assets(mock_fetch_database):
handler = PythonFileHandler()
session = Mock()

os.environ["SAFETY_DB_DIR"] = "/path/to/db"
handler.download_required_assets(session)

_, kwargs = mock_fetch_database.call_args

assert kwargs['db'] == "/path/to/db"

@patch('safety.safety.fetch_database')
def test_download_required_assets_no_db_dir(mock_fetch_database):
handler = PythonFileHandler()
session = Mock()

if "SAFETY_DB_DIR" in os.environ:
del os.environ["SAFETY_DB_DIR"]
handler.download_required_assets(session)

_, kwargs = mock_fetch_database.call_args

assert kwargs['db'] == False

0 comments on commit a1692d5

Please sign in to comment.