Skip to content

Commit

Permalink
v0.1.1
Browse files Browse the repository at this point in the history
v0.1.1
  • Loading branch information
浅梦 authored Apr 7, 2020
2 parents c90bffd + 83e4442 commit 732d8a1
Show file tree
Hide file tree
Showing 10 changed files with 113 additions and 39 deletions.
26 changes: 26 additions & 0 deletions .github/ISSUE_TEMPLATE/bug_report.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''

---

**Describe the bug(问题描述)**
A clear and concise description of what the bug is.

**To Reproduce(复现步骤)**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error

**Operating environment(运行环境):**
- python version [e.g. 3.4, 3.6]
- tensorflow version [e.g. 1.4.0, 1.12.0]
- deepmatch version [e.g. 0.1.1,]

**Additional context**
Add any other context about the problem here.
20 changes: 20 additions & 0 deletions .github/ISSUE_TEMPLATE/feature_request.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: enhancement&feature request
assignees: ''

---

**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]

**Describe the solution you'd like**
A clear and concise description of what you want to happen.

**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.

**Additional context**
Add any other context or screenshots about the feature request here.
20 changes: 20 additions & 0 deletions .github/ISSUE_TEMPLATE/question.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
---
name: Question
about: Ask any question ~
title: ''
labels: question
assignees: ''

---
Please refer to the [FAQ](https://deepmatch.readthedocs.io/en/latest/FAQ.html) in doc and search for the [related issues](https://github.com/shenweichen/DeepCTR/issues) before you ask the question.

**Describe the question(问题描述)**
A clear and concise description of what the question is.

**Additional context**
Add any other context about the problem here.

**Operating environment(运行环境):**
- python version [e.g. 3.6]
- tensorflow version [e.g. 1.4.0,]
- deepmatch version [e.g. 0.1.1,]
10 changes: 10 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
This project is under development and we need developers to participate in.

If you

- familiar with and interested in matching algorithms
- familiar with tensorflow
- have spare time to learn and develop
- familiar with git

please send a brief introduction of your background and experience to [email protected], welcome to join us!
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
[![Disscussion](https://img.shields.io/badge/chat-wechat-brightgreen?style=flat)](./README.md#disscussiongroup)
[![License](https://img.shields.io/github/license/shenweichen/deepmatch.svg)](https://github.com/shenweichen/deepmatch/blob/master/LICENSE)

DeepMatch is a deep matching model library for recommendations, advertising, and search. It's easy to **train models** and to **export representation vectors** for user and item which can be used for **ANN search**.You can use any complex model with `model.fit()`and `model.predict()` .
DeepMatch is a deep matching model library for recommendations & advertising. It's easy to **train models** and to **export representation vectors** for user and item which can be used for **ANN search**.You can use any complex model with `model.fit()`and `model.predict()` .

Let's [**Get Started!**](https://deepmatch.readthedocs.io/en/latest/Quick-Start.html)

Expand Down
2 changes: 1 addition & 1 deletion deepmatch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .utils import check_version

__version__ = '0.1.0'
__version__ = '0.1.1'
check_version(__version__)
66 changes: 32 additions & 34 deletions deepmatch/layers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,53 +153,51 @@ def compute_output_shape(self, input_shape):


class CapsuleLayer(Layer):
def __init__(self, input_units, out_units, max_len, k_max, iteration=3,
weight_initializer=RandomNormal(stddev=1.0), **kwargs):
self.input_units = input_units # E1
self.out_units = out_units # E2
def __init__(self, input_units, out_units, max_len, k_max, iteration_times=3,
initializer=RandomNormal(stddev=1.0), **kwargs):
self.input_units = input_units
self.out_units = out_units
self.max_len = max_len
self.k_max = k_max
self.iteration = iteration
self.weight_initializer = weight_initializer
self.iteration_times = iteration_times
self.initializer = initializer
super(CapsuleLayer, self).__init__(**kwargs)

def build(self, input_shape):
self.B_matrix = self.add_weight(shape=[1, self.k_max, self.max_len], initializer=self.weight_initializer,
trainable=False, name="B", dtype=tf.float32) # [1,K,H]
self.S_matrix = self.add_weight(shape=[self.input_units, self.out_units], initializer=self.weight_initializer,
name="S", dtype=tf.float32)
self.routing_logits = self.add_weight(shape=[1, self.k_max, self.max_len], initializer=self.initializer,
trainable=False, name="B", dtype=tf.float32)
self.bilinear_mapping_matrix = self.add_weight(shape=[self.input_units, self.out_units],
initializer=self.initializer,
name="S", dtype=tf.float32)
super(CapsuleLayer, self).build(input_shape)

def call(self, inputs, **kwargs): # seq_len:[B,1]
low_capsule, seq_len = inputs
B = tf.shape(low_capsule)[0]
seq_len_tile = tf.tile(seq_len, [1, self.k_max]) # [B,K]

for i in range(self.iteration):
mask = tf.sequence_mask(seq_len_tile, self.max_len) # [B,K,H]
pad = tf.ones_like(mask, dtype=tf.float32) * (-2 ** 16 + 1) # [B,K,H]
B_tile = tf.tile(self.B_matrix, [B, 1, 1]) # [B,K,H]
B_mask = tf.where(mask, B_tile, pad)
W = tf.nn.softmax(B_mask) # [B,K,H]
low_capsule_new = tf.tensordot(low_capsule, self.S_matrix, axes=1) # [B,H,E2]
high_capsule_tmp = tf.matmul(W, low_capsule_new) # [B,K,E2]
high_capsule = squash(high_capsule_tmp) # [B,K,E2]

# ([B,K,E2], [B,H,E2]->[B,E2,H])->[B,K,H]->[1,K,H]
B_delta = tf.reduce_sum(
tf.matmul(high_capsule, tf.transpose(low_capsule_new, perm=[0, 2, 1])),
def call(self, inputs, **kwargs):
behavior_embddings, seq_len = inputs
batch_size = tf.shape(behavior_embddings)[0]
seq_len_tile = tf.tile(seq_len, [1, self.k_max])

for i in range(self.iteration_times):
mask = tf.sequence_mask(seq_len_tile, self.max_len)
pad = tf.ones_like(mask, dtype=tf.float32) * (-2 ** 32 + 1)
routing_logits_with_padding = tf.where(mask, tf.tile(self.routing_logits, [batch_size, 1, 1]), pad)
weight = tf.nn.softmax(routing_logits_with_padding)
behavior_embdding_mapping = tf.tensordot(behavior_embddings, self.bilinear_mapping_matrix, axes=1)
Z = tf.matmul(weight, behavior_embdding_mapping)
interet_capsules = squash(Z)
delta_routing_logits = tf.reduce_sum(
tf.matmul(interet_capsules, tf.transpose(behavior_embdding_mapping, perm=[0, 2, 1])),
axis=0, keep_dims=True
) # [1,K,H]
self.B_matrix.assign_add(B_delta)
high_capsule = tf.reshape(high_capsule, [-1, self.k_max, self.out_units])
return high_capsule
)
self.routing_logits.assign_add(delta_routing_logits)
interet_capsules = tf.reshape(interet_capsules, [-1, self.k_max, self.out_units])
return interet_capsules

def compute_output_shape(self, input_shape):
return (None, self.k_max, self.out_units)


def squash(inputs):
vec_squared_norm = tf.reduce_sum(tf.square(inputs), axis=-1, keep_dims=True)
scalar_factor = vec_squared_norm / (1 + vec_squared_norm) / tf.sqrt(vec_squared_norm + 1e-9)
vec_squashed = scalar_factor * inputs # element-wise
scalar_factor = vec_squared_norm / (1 + vec_squared_norm) / tf.sqrt(vec_squared_norm + 1e-8)
vec_squashed = scalar_factor * inputs
return vec_squashed
2 changes: 1 addition & 1 deletion deepmatch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.python.keras._impl.keras.layers import Lambda
from tensorflow.python.keras.layers import Lambda

def recall_N(y_true, y_pred, N=50):
return len(set(y_pred[:N]) & set(y_true)) * 1.0 / len(y_true)
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
# The short X.Y version
version = ''
# The full version, including alpha/beta/rc tags
release = '0.1.0'
release = '0.1.1'


# -- General configuration ---------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

setuptools.setup(
name="deepmatch",
version="0.1.0",
version="0.1.1",
author="Weichen Shen",
author_email="[email protected]",
description="Deep matching model library for recommendations, advertising, and search. It's easy to train models and to **export representation vectors** for user and item which can be used for **ANN search**.",
Expand Down

0 comments on commit 732d8a1

Please sign in to comment.