
[#3] Implement namedtuples for handling input data #6

Open
wants to merge 10 commits into base: main
139 changes: 139 additions & 0 deletions CATS/inputs.py
@@ -0,0 +1,139 @@
from collections import namedtuple
from typing import Literal
DEFAULT_GROUP_NAME = "default_group"


class SparseFeat(namedtuple('SparseFeat',
Collaborator (inline comment): A comment to provide guidance (temporary).

                        ['name', 'vocabulary_size', 'embedding_dim', 'use_hash', 'dtype', 'embedding_name',
                         'group_name'])):
    __slots__ = ()
Collaborator: Could you explain what __slots__ is for and how it is used?

Collaborator (Author): A regular class stores its attributes as keys and values in a per-instance dictionary (__dict__) when an instance is created. The advantage of the dictionary is that member variables can easily be added dynamically. The downside is that, because the members are dynamic, using a wrong member name can cause errors, and every object has to manage its own dictionary, so the memory overhead is large.
Unlike __dict__, __slots__ declares a static set of member variables, so each instance uses a fixed amount of memory and can be managed efficiently. For that reason, when a large number of these objects will be stored, managing them with __slots__ is more efficient.
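A minimal sketch of the difference being described (the Plain / Slotted classes below are hypothetical examples, not part of this PR):

import sys

class Plain:
    def __init__(self, name):
        self.name = name           # stored in the per-instance __dict__

class Slotted:
    __slots__ = ('name',)          # fixed attribute set, no per-instance __dict__

    def __init__(self, name):
        self.name = name

p, s = Plain('a'), Slotted('a')
p.typo = 1                         # silently creates a new attribute on p
# s.typo = 1                       # would raise AttributeError: no such slot
print(hasattr(p, '__dict__'), hasattr(s, '__dict__'))   # True False
print(sys.getsizeof(p.__dict__))   # per-instance dict overhead that Slotted avoids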

Collaborator: namedtuple also creates immutable objects, though. I'm curious whether there is an additional reason you explicitly wrote __slots__ = ()!

Collaborator (Author): Declaring __slots__ as () makes it impossible to set attributes dynamically and increases memory efficiency. Since there is currently no need to change these objects' member variables dynamically, I wrote __slots__ = () for memory efficiency.
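A minimal sketch of that effect on a namedtuple subclass (the Base / WithDict / NoDict names are hypothetical, not from this PR):

from collections import namedtuple

Base = namedtuple('Base', ['x'])

class WithDict(Base):
    pass                 # subclass without __slots__ gets a per-instance __dict__ again

class NoDict(Base):
    __slots__ = ()       # keeps instances dict-free, like the namedtuple itself

w, n = WithDict(1), NoDict(1)
w.extra = 'allowed'      # works: stored in w.__dict__
# n.extra = 'blocked'    # would raise AttributeError: 'NoDict' object has no attribute 'extra'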


    def __new__(cls, name: str, vocabulary_size: int, embedding_dim=4, use_hash=False, dtype="int32",
                embedding_name=None, group_name=DEFAULT_GROUP_NAME):
        """
        Returns information about a single categorical feature.
        :param name: feature's name
        :param vocabulary_size: number of distinct categories (vocabulary size)
        :param embedding_dim: dimension of the converted embedding ('auto' derives it from the vocabulary size)
        :param use_hash: whether to use hashing
        :param dtype: data type
        :param embedding_name: embedding's name (defaults to the feature name)
        :param group_name: group's name
        """
        if embedding_name is None:
            embedding_name = name
        if embedding_dim == 'auto':
            embedding_dim = 6 * int(pow(vocabulary_size, 0.25))
        if use_hash:
            raise NotImplementedError("Feature hashing is not supported in the PyTorch version. "
                                      "Please use TensorFlow or disable hashing.")
        return super(SparseFeat, cls).__new__(cls, name, vocabulary_size, embedding_dim, use_hash, dtype,
                                              embedding_name, group_name)

    def __hash__(self):
        """
        Determines the hash value based on the name.
        :return: self.name's hash
        """
        return self.name.__hash__()

class VarLenSparseFeat(namedtuple('VarLenSparseFeat',
                                  ['sparsefeat', 'maxlen', 'combiner', 'length_name'])):
    __slots__ = ()

    def __new__(cls, sparsefeat: SparseFeat, maxlen: int, combiner: Literal['sum', 'mean', 'max'], length_name=None):
        """
        Returns information about a variable-length (sequence) categorical feature.
        :param sparsefeat: namedtuple holding the underlying categorical feature's info
        :param maxlen: maximum length of the category sequence
        :param combiner: combining method for the embeddings ('sum', 'mean', 'max')
        :param length_name: name of the feature holding the actual sequence length
        """
        return super(VarLenSparseFeat, cls).__new__(cls, sparsefeat, maxlen, combiner, length_name)

    @property
    def name(self):
        """
        VarLenSparseFeat's name
        :return: sparsefeat.name
        """
        return self.sparsefeat.name

    @property
    def vocabulary_size(self):
        """
        VarLenSparseFeat's vocabulary size
        :return: sparsefeat.vocabulary_size
        """
        return self.sparsefeat.vocabulary_size

    @property
    def embedding_dim(self):
        """
        VarLenSparseFeat's embedding dimension
        :return: sparsefeat.embedding_dim
        """
        return self.sparsefeat.embedding_dim

    @property
    def use_hash(self):
        """
        whether to use hash
        :return: sparsefeat.use_hash
        """
        return self.sparsefeat.use_hash

    @property
    def dtype(self):
        """
        data's type
        :return: sparsefeat.dtype
        """
        return self.sparsefeat.dtype

    @property
    def embedding_name(self):
        """
        embedding's name
        :return: sparsefeat.embedding_name
        """
        return self.sparsefeat.embedding_name

    @property
    def group_name(self):
        """
        group's name
        :return: sparsefeat.group_name
        """
        return self.sparsefeat.group_name

    def __hash__(self):
        """
        Determines the hash value based on the name.
        :return: self.name's hash
        """
        return self.name.__hash__()
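A hedged usage sketch showing how the wrapper delegates to the underlying SparseFeat (the feature names are illustrative, not from this PR):

genres = VarLenSparseFeat(
    SparseFeat('genres', vocabulary_size=300, embedding_dim=8),
    maxlen=5,                    # each sample lists at most 5 genres
    combiner='mean',             # pool the up-to-5 embeddings by averaging
    length_name='genres_len',    # feature holding each sample's true length
)

# the properties above forward to the wrapped SparseFeat
print(genres.name, genres.vocabulary_size, genres.embedding_dim)  # genres 300 8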


class DenseFeat(namedtuple('Dense',
                           ['name', 'dimension', 'dtype'])):
    __slots__ = ()

    def __new__(cls, name: str, dimension=1, dtype="float32"):
        """
        Returns information about a numeric feature.
        :param name: numeric feature's attribute name
        :param dimension: number of dimensions
        :param dtype: data type
        """
        if not isinstance(dimension, int) or dimension < 1:
            raise ValueError("dimension must be a positive integer")
        return super(DenseFeat, cls).__new__(cls, name, dimension, dtype)

    def __hash__(self):
        """
        Determines the hash value based on the name.
        :return: self.name's hash
        """
        return self.name.__hash__()
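A hedged usage sketch for the numeric feature (names and values are illustrative, not from this PR):

price = DenseFeat('price', dimension=1)
user_stats = DenseFeat('user_stats', dimension=8, dtype="float32")

# invalid dimensions are rejected in __new__
try:
    DenseFeat('bad', dimension=0)
except ValueError as e:
    print(e)   # dimension must be a positive integer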