Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update deep_q_network.py #93

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions deep_q_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
import wrapped_flappy_bird as game
import random
import numpy as np
import numpy as np
import cvlib as cv
from collections import deque

BOUNDARY = 0.038 # 表示连续4帧中,首尾帧中小鸟y轴变化距离与图片高的比值,高于这个比值则需要惩罚
GAME = 'bird' # the name of the game being played for log files
ACTIONS = 2 # number of valid actions
GAMMA = 0.99 # decay rate of past observations
Expand All @@ -21,6 +24,51 @@
BATCH = 32 # size of minibatch
FRAME_PER_ACTION = 1

def get_highest_conf_bird_center(image):
"""
处理图像并获取置信度最高的 'bird' 标签的中心坐标。
如果没有找到 'bird' 标签,返回 (0, 0)。
"""
# 调整图像大小为 480x480,转换为 BGR 确保图像数据类型是 uint8
resized_image = cv2.resize(image, (480, 480))
bgr_image = cv2.cvtColor(resized_image, cv2.COLOR_RGB2BGR)
bgr_image = bgr_image.astype(np.uint8)

# 进行物体检测
bbox, label, conf = cv.detect_common_objects(bgr_image)

max_conf = 0
center_x, center_y = 0, 0
# 遍历检测到的物体,找到置信度最高的 'bird'
for i in range(len(label)):
if label[i] == 'bird':
if conf[i] > max_conf:
max_conf = conf[i]
# 获取边界框并计算中心坐标
x1, y1, x2, y2 = bbox[i]
center_x = (x1 + x2) / 2
center_y = (y1 + y2) / 2

return center_y

def new_reward(first_frame, last_frame):
"""
计算两帧图像中 'bird' 标签的纵坐标变化。
并根据变化率,得到具体奖惩数值。
"""
# 获取 first_frame last_frame 中置信度最高的 'bird' 中心纵坐标
center_y_first = get_highest_conf_bird_center(first_frame)
center_y_last = get_highest_conf_bird_center(last_frame)

# 计算纵坐标变化,取绝对值并除以 480
vertical_change = abs(center_y_first - center_y_last) / 480
if vertical_change >= BOUNDARY: # 小鸟移动速度过快,实施惩罚与速度正相关
reward = -vertical_change
else: # 鼓励小鸟平稳缓慢移动,实施奖励与速度负相关
reward = BOUNDARY - vertical_change

return reward

def weight_variable(shape):
initial = tf.truncated_normal(shape, stddev = 0.01)
return tf.Variable(initial)
Expand Down Expand Up @@ -136,6 +184,8 @@ def trainNetwork(s, readout, h_fc1, sess):

# run the selected action and observe next state and reward
x_t1_colored, r_t, terminal = game_state.frame_step(a_t)
# 细化奖励
r_t = new_reward(x_t1_colored, s_t[:, :, -1])
x_t1 = cv2.cvtColor(cv2.resize(x_t1_colored, (80, 80)), cv2.COLOR_BGR2GRAY)
ret, x_t1 = cv2.threshold(x_t1, 1, 255, cv2.THRESH_BINARY)
x_t1 = np.reshape(x_t1, (80, 80, 1))
Expand Down