Skip to content

Commit

Permalink
added env variables to docker-compose.yml for better changablity. Add…
Browse files Browse the repository at this point in the history
…ed uniq folder for each api request to save and edit files.
  • Loading branch information
Dan committed Aug 1, 2023
1 parent 300816b commit 9ce4a29
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 44 deletions.
5 changes: 1 addition & 4 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
FROM python:3.10

ENV EXECUTION_PROVIDER=CPU
ENV TEMP_FRAME_FORMAT=png
ENV TEMP_FRAME_QUALITY=0
ENV OUTPUT_VIDEO_QUALITY=0

RUN apt-get update && apt-get install -y ffmpeg
RUN pip install --upgrade pip
#RUN git clone https://github.com/s0md3v/roop.git /roop
RUN ls

RUN git clone https://github.com/danikhani/roop.git -b docker /roop
RUN pip install -r roop/requirements-docker.txt

Expand Down
94 changes: 58 additions & 36 deletions api.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
import sys
import os
import shutil
from fastapi import FastAPI, File, UploadFile, Depends, Response
import uuid

from fastapi import FastAPI, File, UploadFile, Depends, Response, BackgroundTasks
from fastapi.responses import FileResponse
import uvicorn
from pydantic import BaseModel
from typing import Optional, Literal

import roop.globals
from roop.core import decode_execution_providers, suggest_execution_threads, limit_resources, start, \
get_frame_processors_modules, suggest_execution_providers
get_frame_processors_modules
from roop.utilities import resolve_relative_path


TEMP_FILES_PATH: str
app = FastAPI()


class RoopModel(BaseModel):
frame_processor: Optional[list] = ['face_swapper']
keep_fps: Optional[bool] = False
Expand All @@ -25,23 +28,17 @@ class RoopModel(BaseModel):
reference_frame_number: Optional[int] = 0
similar_face_distance: Optional[float] = 0.85
output_video_encoder: Optional[Literal['libx264', 'libx265', 'libvpx-vp9', 'h264_nvenc', 'hevc_nvenc']] = 'libx264'
max_memory: Optional[int] = 0
execution_threads: Optional[int] = suggest_execution_threads()


@app.post("/start_roop")
async def image_file(
src_file: UploadFile = File(...),
target_file: UploadFile = File(...),
roop_parameters: RoopModel = Depends()
):
# Removing the folder and its content if it already exists
saving_path = resolve_relative_path('../workdir/')
if os.path.exists(saving_path):
shutil.rmtree(saving_path)
os.makedirs(saving_path)

#Get execution provider from env.
def set_init_global_params():
global TEMP_FILES_PATH
TEMP_FILES_PATH = resolve_relative_path('../tmp/')
if os.path.exists(TEMP_FILES_PATH):
shutil.rmtree(TEMP_FILES_PATH)
os.makedirs(TEMP_FILES_PATH)

# Get execution provider from env.
execution_provider = os.getenv('EXECUTION_PROVIDER')
print("execution provider is set to {}".format(execution_provider))
if execution_provider == 'CUDA':
Expand All @@ -51,26 +48,18 @@ async def image_file(
else:
execution_provider_list = ['cpu']
print(execution_provider_list)
roop.globals.execution_providers = decode_execution_providers(execution_provider_list)

roop.globals.temp_frame_format = os.getenv('TEMP_FRAME_FORMAT')
roop.globals.temp_frame_quality = os.getenv('TEMP_FRAME_QUALITY')
roop.globals.output_video_quality = os.getenv('OUTPUT_VIDEO_QUALITY')

# setting paths
src_saving_path_complete = os.path.join(saving_path, src_file.filename)
target_saving_path_complete = os.path.join(saving_path, target_file.filename)
output_saving_path_complete = os.path.join(saving_path, 'output_' + target_file.filename)
with open(src_saving_path_complete, "wb+") as file_object:
file_object.write(src_file.file.read())
with open(target_saving_path_complete, 'wb+') as file_obj:
file_obj.write(target_file.file.read())

roop.globals.source_path = src_saving_path_complete
roop.globals.target_path = target_saving_path_complete
roop.globals.output_path = output_saving_path_complete
roop.globals.max_memory = os.getenv('MAX_MEMORY')

# Since API starts the roop. It is always headless
roop.globals.headless = True


def set_request_params(roop_parameters: RoopModel):
# Setting other roop parameters
roop.globals.frame_processors = roop_parameters.frame_processor
roop.globals.keep_fps = roop_parameters.keep_fps
Expand All @@ -81,20 +70,53 @@ async def image_file(
roop.globals.reference_frame_number = roop_parameters.reference_frame_number
roop.globals.similar_face_distance = roop_parameters.similar_face_distance
roop.globals.output_video_encoder = roop_parameters.output_video_encoder
roop.globals.max_memory = roop_parameters.max_memory
roop.globals.execution_providers = decode_execution_providers(execution_provider_list)
roop.globals.execution_threads = roop_parameters.execution_threads


def remove_request_dir(path:str):
if os.path.exists(path):
shutil.rmtree(path)


@app.post("/start_roop")
async def image_file(
background_tasks: BackgroundTasks,
src_file: UploadFile = File(...),
target_file: UploadFile = File(...),
roop_parameters: RoopModel = Depends()
):
# Making a uniq folder for each request.
request_uuid_str = str(uuid.uuid4())
request_dir_path = os.path.join(TEMP_FILES_PATH, request_uuid_str)
if not os.path.exists(request_dir_path):
os.makedirs(request_dir_path)
# Setting request parameters.
set_request_params(roop_parameters)

# uniq paths for the request
src_saving_path_complete = os.path.join(request_dir_path, src_file.filename)
target_saving_path_complete = os.path.join(request_dir_path, target_file.filename)
output_saving_path_complete = os.path.join(request_dir_path, 'output_' + target_file.filename)
with open(src_saving_path_complete, "wb+") as file_object:
file_object.write(src_file.file.read())
with open(target_saving_path_complete, 'wb+') as file_obj:
file_obj.write(target_file.file.read())

roop.globals.source_path = src_saving_path_complete
roop.globals.target_path = target_saving_path_complete
roop.globals.output_path = output_saving_path_complete

for frame_processor in get_frame_processors_modules(roop.globals.frame_processors):
if not frame_processor.pre_check():
sys.exit()
limit_resources()
start()

# Background task to remove the output after it has been sent in response.
background_tasks.add_task(remove_request_dir, path=request_dir_path)
return FileResponse(roop.globals.output_path)


if __name__ == "__main__":
if os.getenv('EXECUTION_PROVIDER') is None:
print('Env variable for execution provider is not set. Setting default to CPU.')
os.environ['EXECUTION_PROVIDER'] = 'CPU'
print(os.getenv('EXECUTION_PROVIDER'))
set_init_global_params()
uvicorn.run(app, host="0.0.0.0", port=8000)
7 changes: 6 additions & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@ version: '3.8'

services:
roop-api:
environment:
- TEMP_FRAME_FORMAT=png
- TEMP_FRAME_QUALITY=0
- OUTPUT_VIDEO_QUALITY=0
- MAX_MEMORY=0
build:
context: .
dockerfile: Dockerfile
restart: always
volumes:
- roop-data:/roop
- roop-data:/roop/models
ports:
- 8000:8000

Expand Down
8 changes: 5 additions & 3 deletions roop/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
THREAD_LOCK = threading.Lock()
MAX_PROBABILITY = 0.85

rel_path_to_nsfw_model = resolve_relative_path('../models/nsfw2')

def get_predictor() -> Model:
global PREDICTOR

with THREAD_LOCK:
if PREDICTOR is None:
PREDICTOR = opennsfw2.make_open_nsfw_model(weights_path='/roop/models/nsfw2')
PREDICTOR = opennsfw2.make_open_nsfw_model(weights_path=rel_path_to_nsfw_model)
return PREDICTOR


Expand All @@ -36,9 +38,9 @@ def predict_frame(target_frame: Frame) -> bool:


def predict_image(target_path: str) -> bool:
return opennsfw2.predict_image(image_path=target_path,weights_path='/roop/models/nsfw2') > MAX_PROBABILITY
return opennsfw2.predict_image(image_path=target_path,weights_path=rel_path_to_nsfw_model) > MAX_PROBABILITY


def predict_video(target_path: str) -> bool:
_, probabilities = opennsfw2.predict_video_frames(video_path=target_path, weights_path='/roop/models/nsfw2', frame_interval=100)
_, probabilities = opennsfw2.predict_video_frames(video_path=target_path, weights_path=rel_path_to_nsfw_model, frame_interval=100)
return any(probability > MAX_PROBABILITY for probability in probabilities)

0 comments on commit 9ce4a29

Please sign in to comment.