-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0df347e
commit 598946d
Showing
7 changed files
with
272 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,3 +17,4 @@ venv/ | |
*test*.py | ||
/*test*/* | ||
/*test*/**/* | ||
/bin/* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# flake8: noqa | ||
|
||
from .func import * | ||
from .util import * | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
import warnings | ||
|
||
import numpy as np | ||
from vsexprtools import norm_expr | ||
from vstools import (CustomValueError, FuncExceptT, FunctionUtil, SPath, | ||
SPathLike, clip_async_render, core, fallback, vs) | ||
|
||
from .util import get_format_from_npz | ||
|
||
__all__: list[str] = [ | ||
'prepare_clip_for_npz', 'finalize_clip_from_npz', | ||
'clip_to_npz', 'npz_to_clip', | ||
] | ||
|
||
|
||
def prepare_clip_for_npz(clip: vs.VideoNode, func_except: FuncExceptT | None = None) -> vs.VideoNode: | ||
""" | ||
Prepare a clip for exporting to numpy files. | ||
This involves dithering up to 32-bit float and normalizing the UV ranges to [0, 1] if present. | ||
This should be used before exporting frames to numpy files. | ||
:param clip: The input video clip to process. | ||
:param func_except: Function returned for custom error handling. | ||
This should only be set by VS package developers. | ||
:return: The processed clip. | ||
""" | ||
|
||
return _process_clip_for_npz(clip, fallback(func_except, prepare_clip_for_npz), 'prepare') | ||
|
||
|
||
def finalize_clip_from_npz(clip: vs.VideoNode, func_except: FuncExceptT | None = None) -> vs.VideoNode: | ||
""" | ||
Finalize a clip obtained from numpy files. | ||
This involves denormalizing the UV ranges to the original range if present. | ||
This should be used after loading frames from numpy files. | ||
:param clip: The input video clip to process. | ||
:param func_except: Function returned for custom error handling. | ||
This should only be set by VS package developers. | ||
:return: The processed clip. | ||
""" | ||
|
||
return _process_clip_for_npz(clip, fallback(func_except, finalize_clip_from_npz), 'finalize') | ||
|
||
|
||
def _process_clip_for_npz(clip: vs.VideoNode, func_except: FuncExceptT | None, operation: str) -> vs.VideoNode: | ||
func = FunctionUtil(clip, fallback(func_except, _process_clip_for_npz), None, (vs.GRAY, vs.YUV), 32) | ||
|
||
if not func.chroma_planes: | ||
return func.work_clip | ||
|
||
return norm_expr(func.work_clip, 'x 0.5 +' if operation == 'prepare' else 'x 0.5 -', func.chroma_planes) | ||
|
||
|
||
def clip_to_npz(src: vs.VideoNode, out_dir: SPathLike = 'bin/') -> list[SPath]: | ||
""" | ||
Export frames from a VideoNode to numpy array files. | ||
This function is intended to be used to help with preparing training data for neural networks. | ||
The function will not overwrite existing files, | ||
and instead increments the next filename by 1. | ||
:param src: The input video clip. | ||
:param out_dir: The directory to save the numpy arrays. | ||
Default: "bin/". | ||
:return: A list of paths to the exported numpy arrays. | ||
:raises RuntimeWarning: If any frames failed to process. | ||
""" | ||
|
||
func = FunctionUtil(src, clip_to_npz, None, vs.YUV) | ||
|
||
proc_clip = func.work_clip | ||
|
||
out_dir = SPath(out_dir) | ||
out_dir.mkdir(511, True, True) | ||
|
||
next_name = max((int(f.stem) for f in out_dir.glob('*.npz')), default=0) + 1 | ||
|
||
if not (total_frames := len(proc_clip)): | ||
return [] | ||
|
||
try: | ||
from tqdm import tqdm | ||
pbar = tqdm(total=total_frames, unit='frame', desc=f'Dumping numpy arrays to {out_dir}...') | ||
except ImportError: | ||
pbar = None | ||
|
||
def _update_progress(filename: str | None = None): | ||
if not pbar: | ||
return | ||
|
||
pbar.update(1) | ||
|
||
if filename: | ||
pbar.set_postfix({'Current file': filename}, refresh=True) | ||
|
||
exported_files = [] | ||
|
||
def _process_frame(n: int, frame: vs.VideoFrame): | ||
nonlocal next_name | ||
|
||
try: | ||
frame_data = np.array([( | ||
np.asarray(frame[0]), | ||
np.asarray(frame[1]) if frame.format.num_planes > 1 else None, | ||
np.asarray(frame[2]) if frame.format.num_planes > 1 else None | ||
)], dtype=[('Y', object), ('U', object), ('V', object)]) | ||
|
||
filename = f'{next_name:05d}.npz' | ||
file_path = out_dir / filename | ||
|
||
np.save(file_path, frame_data, allow_pickle=True) | ||
|
||
next_name += 1 | ||
|
||
_update_progress(filename) | ||
|
||
exported_files.append(file_path) | ||
|
||
return filename | ||
except Exception as e: | ||
print(f'Error processing frame {n} ({str(e)})') | ||
|
||
_update_progress() | ||
|
||
return None | ||
|
||
proc_frames = clip_async_render(proc_clip, callback=_process_frame) | ||
|
||
if pbar: | ||
pbar.close() | ||
|
||
if failed_frames := [f for f in proc_frames if f is None]: | ||
warnings.warn( | ||
f'export_frames_to_npz: {len(failed_frames)} frames failed to process ({failed_frames}).', | ||
RuntimeWarning | ||
) | ||
|
||
return exported_files | ||
|
||
|
||
def npz_to_clip(file_paths: list[SPathLike] | SPathLike = []) -> vs.VideoNode: | ||
""" | ||
Read numpy files and convert them to a VapourSynth clip. | ||
:param file_paths: The list of numpy files to convert to a clip. | ||
If a directory is provided, all .npz files in the directory will be used. | ||
If a single file is provided, it will be used instead. | ||
:return: The clip. | ||
""" | ||
|
||
if not isinstance(file_paths, list): | ||
file_paths = SPath(file_paths) | ||
|
||
if file_paths.is_dir(): | ||
file_paths = list(file_paths.glob("*.npz")) | ||
else: | ||
file_paths = [file_paths] | ||
|
||
if not file_paths: | ||
raise CustomValueError("No files provided", npz_to_clip) | ||
|
||
file_paths = sorted(file_paths, key=lambda x: int(x.stem)) | ||
|
||
first_frame = np.load(file_paths[0], allow_pickle=True)[0] | ||
height, width = first_frame['Y'].shape | ||
|
||
format = get_format_from_npz(first_frame) | ||
|
||
blank_clip = core.std.BlankClip(None, width, height, format, length=len(file_paths), keep=True) | ||
|
||
def _read_frame(n: int, f: vs.VideoFrame) -> vs.VideoNode: | ||
loaded_frame = np.load(file_paths[n], allow_pickle=True)[0] | ||
|
||
fout = f.copy() | ||
|
||
for plane in range(f.format.num_planes): | ||
plane_data = loaded_frame[f.format.name[plane]] | ||
np.copyto(np.asarray(fout[plane]), plane_data) | ||
|
||
return fout | ||
|
||
return blank_clip.std.ModifyFrame(blank_clip, _read_frame) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import numpy as np | ||
from vstools import (CustomRuntimeError, FuncExceptT, InvalidVideoFormatError, | ||
core, depth, fallback, get_video_format, vs) | ||
|
||
__all__: list[str] = [ | ||
'get_format_from_npz', | ||
] | ||
|
||
|
||
def get_format_from_npz(frame_data: np.ndarray, func_except: FuncExceptT | None = None) -> vs.VideoFormat: | ||
""" | ||
Guess the format based on heuristics from the numpy array data. | ||
Input is assumed to be a numpy array with the following structure: | ||
[ | ||
('Y', np.ndarray), | ||
('U', np.ndarray | None), | ||
('V', np.ndarray | None) | ||
] | ||
If every array has the same shape, it's assumed to be YUV 4:4:4. | ||
If you output RGB data, you may have to convert it back. | ||
If the U and V arrays are None, it's assumed to be GRAY. | ||
:param frame_data: The numpy array data to guess the format from. | ||
:param func_except: Function returned for custom error handling. | ||
This should only be set by VS package developers. | ||
:return: The guessed format. | ||
""" | ||
|
||
func = fallback(func_except, get_format_from_npz) | ||
|
||
y_data, u_data, v_data = frame_data['Y'], frame_data['U'], frame_data['V'] | ||
|
||
bit_depth = 32 if y_data.dtype == np.float32 else y_data.itemsize * 8 | ||
sample_type = vs.FLOAT if bit_depth == 32 else vs.INTEGER | ||
|
||
if (u_data is None) != (v_data is None): | ||
raise CustomRuntimeError('U and V planes must both be present or both be None', func) | ||
|
||
if u_data is None or v_data is None: | ||
return get_video_format( | ||
depth(core.std.BlankClip(format=vs.GRAY8, keep=True), bit_depth, sample_type=sample_type) | ||
) | ||
|
||
y_shape, u_shape, v_shape = y_data.shape, u_data.shape, v_data.shape | ||
|
||
if u_shape != v_shape: | ||
raise InvalidVideoFormatError('U and V planes must have the same shape', func) | ||
|
||
if y_shape == u_shape: | ||
subsampling = vs.YUV444P8 | ||
elif u_shape[0] == y_shape[0] and u_shape[1] == y_shape[1] // 2: | ||
subsampling = vs.YUV422P8 | ||
elif u_shape[0] == y_shape[0] // 2 and u_shape[1] == y_shape[1] // 2: | ||
subsampling = vs.YUV420P8 | ||
else: | ||
raise InvalidVideoFormatError(f'Unknown subsampling! {y_shape=}, {u_shape=}, {v_shape=}', func) | ||
|
||
try: | ||
# TODO: Figure out smarter way to get the exact format directly | ||
# If only a str overload existed for get_video_format... | ||
return get_video_format( | ||
depth(core.std.BlankClip(format=subsampling, keep=True), bit_depth, sample_type=sample_type) | ||
) | ||
except AttributeError: | ||
raise InvalidVideoFormatError(f'Unsupported format: {subsampling=} {sample_type=} {bit_depth=}', func) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,3 +8,5 @@ vskernels>=2.4.1 | |
vsmasktools>=1.1.2 | ||
vsrgtools>=1.5.1 | ||
stgfunc>=3.1.0 | ||
numpy>=2.1.1 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters