Re3网络的训练过程


Re3: Real-Time Recurrent Regression Networks for Visual Tracking of Generic Objects

Re3论文:https://arxiv.org/pdf/1705.06368.pdf

Re3代码:https://gitlab.com/danielgordon10/re3-tensorflow

作者主页:https://homes.cs.washington.edu/~xkcd/index.html


按代码,画一个详细的网络结构:

Traning details in README

training:

  • batch_cache.py - A local data server useful for reading in images from disc onto RAM. The images will be sent over a pipe to the training process.
  • caffe_to_tf.py - The script used to port Caffe weights to Tensorflow.
  • read_gt.py - Reads in data stored in a variety of different formats into a common(ish) format between all datasets.
  • test_net.py - Runs through a provided sequence of images, tracking and computing IOUs on all ground truth frames. It can also make videos of the results.
  • unrolled_solver.py - The big, bad training file. This unwieldy piece of code encapsulates much of the complexity involved in training Re3 in terms of data management, image preprocessing, and parallelization. It also contains lots of code for logging and debugging that I found immensely useful. Look through the command-line arguments and constants declared at the top for more knobs to tune during training. If you kill a training run with Ctrl-c, it will save the current weights (which I also found very useful) before it extits.

The Readme file in /re3-tensorflow-master/training is also important.


ILSVRC dataset

这里有个错误,我将ILSVRC中的 CLS-LOC 当做 DET 了,不过影响应该不大

VID

ILSVRC2015,下载地址:http://bvisionweb1.cs.unc.edu/ilsvrc2015/download-videos-3j16.php#vid

DET

因为没有找到ILSVRC2015DET(imagenet暂时打不开),所以这里就用ILSVRC2012代替,下载地址:http://academictorrents.com/collection/imagenet-lsvrc-2015

我下载的数据集中,Data/DET/train下面有1000个文件夹(代表1000类物体),每个文件夹中有1300张图片,但是对应的Annotations/DET/train下面对应的文件夹的xml文件不足1300个(所有xml文件共544546个,val的图片全都有对应的xml,也是50000个),意味着有的图片没有annotation信息,所以应该按照Annotations中的信息去找对应的图片进行训练和验证。

解压脚本:

下载的DET/train里是1000个压缩文件,需要解压:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import os

dataset_path = '/home/ubuntu/Downloads/dataset/ILSVRC/Data/DET/train'

file_list = os.listdir(dataset_path)

for i, tar_file in enumerate(file_list):

if len(tar_file.split('.')) != 2 or tar_file.split('.')[1] != 'tar':
continue

tar_file = dataset_path + '/' + tar_file
tar_dir = tar_file.split('.')[0] # make dir before extract

try:
os.makedirs(tar_dir)
except FileExistsError:
pass

tar_command = 'tar -xvf %s -C %s' % (tar_file, tar_dir)
print('-----------------------------%d-----------------------------' % (i))
print(tar_command)
os.system(tar_command)

其中:

1
tar -xvf a.tar -C b/c

表示将a.tar解压到b/c目录下

VID/make_label_files.py

re3-tensorflow-master/training/datasets/imagenet_video/make_label_files.py生成训练数据。

ln -s生成数据集的软链接到程序的指定位置。

如果用PyCharm打开整个目录作为一个project,会自动加载整个数据集,当数据集非常大时,造成IDE卡顿,可以在project栏将这个数据集链接给忽略掉(右键点击->mark->exclude)。

代码中的知识点

xml.etree.ElementTree

python中用来解析xml的包。

一个xml文件中的内容:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
<annotation>
<folder>ILSVRC2015_VID_train_0000/ILSVRC2015_train_00000000</folder>
<filename>000000</filename>
<source>
<database>ILSVRC_2015</database>
</source>
<size>
<width>1280</width>
<height>720</height>
</size>
<object>
<trackid>0</trackid>
<name>n01674464</name>
<bndbox>
<xmax>1050</xmax>
<xmin>323</xmin>
<ymax>428</ymax>
<ymin>216</ymin>
</bndbox>
<occluded>1</occluded>
<generated>0</generated>
</object>
</annotation>

用以下方式读入数据:

1
2
3
4
import xml.etree.ElementTree as ET

# 获取 ElementTree 对象
xml_obj = ET.parse('abc.xml')

获取的xml_obj是一个特殊类型(Element object),还看不到任何具体的东西。

1
2
for obj in xml_obj.findall('object'):
cls = obj.find('name').text

.findall(tag)用来寻找所有符合要求的tag的Element,所以返回的是一个list,里面元素的类型都是Element object,.find(tag)只返回符合要求的第一个Element。

obj.find(‘name’).text返回一个字符串’n01674464’。

numpy.lexsort

参考numpy.lexsort:

1
2
3
4
5
>>> a = [1,5,1,4,3,4,4] # First column
>>> b = [9,4,0,4,0,2,1] # Second column
>>> ind = np.lexsort((b,a)) # Sort by a, then by b
>>> print(ind)
[2 0 4 6 5 3 1]

主要按a排序,相同元素再按照对应位置的b元素大小排序。注意参数a的位置在b的后面!

总结

make_label_files.py文件将生成的训练和验证数据以ndarray格式存放在同级目录下的labels/train(val)/labels.npy文件中。其中每一行的数据内容如下:

1
[xmin, ymin, xmax, ymax, vv, trackId, imNum, classInd, occl]

xmin, ymin, xmax, ymax是box左上角和右下角的坐标,vv表示视频的序号,trackId表示一个视频内跟踪物体的ID,imNum表示图片数(训练数据或验证数据的第X张图片),classInd表示类别,用一组特定的序号表示某一类(一共30类,取值为1 - 30),occl表示在这张图片上该物体有没有被遮挡。npy中数据按以下方式排列:

Reorder by video_id (vv), then track_id (trackId), then video image number (imNum) so all labels for a single track are next to each other.This only matters if a single image could have multiple tracks.

DET/make_label_files.py

和上面一样添加软链接,在IDE中忽略该文件夹节省加载时间。

根据我的数据集存放情况,将:

1
wildcard = '/*/*/' if label_type == 'train' else '/'

改为:

1
wildcard = '/*/' if label_type == 'train' else '/'

总结

只选择box面积大于等于图片面积0.01倍的图片,将box信息按[xmin, ymin, xmax, ymax, num]存储在train/val对应文件夹下labels.npy的文件中,文件路径存放在image_names.txt文件中

get_datasets.py

get_datasets.py文件完成的工作是从image_names.txt中提取每张图片的路径(image_paths)和make_label_files.py生成的整个npy文件信息(gt)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def get_data_for_dataset(dataset_name, mode):
# Implement this for each dataset.
if dataset_name == 'imagenet_video':
datadir = os.path.join(
os.path.dirname(__file__),
'datasets',
'imagenet_video')
gt = np.load(datadir + '/labels/' + mode + '/labels.npy')
image_paths = [datadir + '/' + line.strip()
for line in open(datadir + '/labels/' + mode + '/image_names.txt')]
return {
'gt' : gt,
'image_paths' : image_paths,
}

os.path.dirname(__file__)返回py文件所在目录

列表生成式

利用列表生成式可以方便的获取文件中的每一行路径数据:

1
2
image_paths = [datadir + '/' + line.strip()
for line in open(datadir + '/labels/' + mode + '/image_names.txt')]

batch_cache.py (☆key script as server)

代码中的知识点

argparse

1
2
3
4
5
import argparse

parser = argparse.ArgumentParser(description='Server for network images.')
parser.add_argument('-s', '--max_size', action='store', default=100,
dest='max_size', type=int)

用来接收命令行参数。

socketserver

一个用于进程通讯的包,在Python中,socketserver是比socket更高级别的包。

batch_cache.py文件中,创建的是一个TCP服务器:

1
handler = socketserver.TCPServer((HOST, port), BatchCacheHandler)

TCPServersocketserver包中的类,初始化参数传入了一个包含地址和端口号的元组(这里由于是本机两个进程通信,HOST= ‘localhost’,就是127.0.0.1,port默认设置为了9997),第二个参数是自己定义的类(要继承socketserver.BaseRequestHandler类,因为要用到这个类中的handle方法)。服务器与客户端建立链接后,就会用我们传入的这个类,创建一个对象(专门用来和这个客户端进行通信)。而且在传入的BatchCacheHandler类中,重载了父类的handle方法,每次接收到数据自动调用handle方法。

还有两个常用的类:ThreadingTCPServer、ForkingTCPServer

ThreadingTCPServer/TCPServer/ForkingTCPServer的区别,原理可同样引申到UDP

这三个类其实就是对接收到request请求后的不同处理方法。

TCPServer是接收到请求后执行handle方法,如果前一个的handle没有结束,那么其他的请求将不会受理,新的客户端也无法加入。

而ThreadingTCPServer和ForkingTCPServer则允许前一连接的handle未结束也可受理新的请求和连接新的客户端,区别在于前者用建立新线程的方法运行handle,后者用新进程的方法运行handle。

1
handler.serve_forever()

启动TCP服务器。

call the handle_request() orserve_forever() method of the server object to process one or many requests.

往实例中添加方法

当创建了一个实例,想要往实例中添加额外的方法时,可以这样(法一):

1
2
3
4
5
6
7
8
9
10
11
def my_print(a):
print(a)


class Student:
pass

s = Student()
s.my_print = my_print

s.my_print('123')

也可以这样(法二):

1
2
3
4
5
6
7
8
>>> def set_age(self, age): # 定义一个函数作为实例方法
... self.age = age
...
>>> from types import MethodType
>>> s.set_age = MethodType(set_age, s) # 给实例绑定一个方法
>>> s.set_age(25) # 调用实例方法
>>> s.age # 测试结果
25

法一简单,但是添加的函数无法调用self,法二可以。

handle方法调用外部的变量

几个类的定义是这样的:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class BatchCacheHandler(SocketServer.BaseRequestHandler, object):
def handle(self):
while not self.server.shut_down:
blabla

class BatchCacheServer:

blabla

def serve(self, port):
if self.debug:
print('Server starting up')
handler = SocketServer.TCPServer((HOST, port), BatchCacheHandler)
handler.get_sample = self.get_sample
handler.batch_cache = self
handler.lock = self.data_lock
handler.shut_down = self.shut_down
handler.serve_forever()


server = BatchCacheServer(args)
server.serve(args.port)

每当一个客户端连接,就会实例化一个BatchCacheHandler对象,当收到客户端的消息时,自动执行重载的handle方法,可以在handle方法中使用self.server.xxx访问handler,也就是TCPServer实例的成员内容。所以在创建handler实例之后,又添加了几个方法和变量,注意handler.batch_cache = self,直接传递self。

在batch_cache中,TCP服务器端从客户端接收信息,用来判断链接是否保持。

threading

threading.Thread( )创建线程可以指定执行内容(target)、参数、名字等参数。

廖雪峰教程中的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import time, threading

# 新线程执行的代码:
def loop():
print('thread %s is running...' % threading.current_thread().name)
n = 0
while n < 5:
n = n + 1
print('thread %s >>> %s' % (threading.current_thread().name, n))
time.sleep(1)
print('thread %s ended.' % threading.current_thread().name)

print('thread %s is running...' % threading.current_thread().name)
t = threading.Thread(target=loop, name='LoopThread')
t.start()
t.join()
print('thread %s ended.' % threading.current_thread().name)

结果:

1
2
3
4
5
6
7
8
9
thread MainThread is running...
thread LoopThread is running...
thread LoopThread >>> 1
thread LoopThread >>> 2
thread LoopThread >>> 3
thread LoopThread >>> 4
thread LoopThread >>> 5
thread LoopThread ended.
thread MainThread ended.

t.start( ) 表示启动新线程,t.join( )表示将新线程加入到主线程中,也就是主线程阻塞,等待新线程执行完再结束(如果主线程先结束,所有新线程也随之结束)。如果在类中定义线程,可以把join放在析构函数中:

1
2
def __del__(self):
new_thread.join()

threading.Thread(target=XX, name=XX)代表着新线程只执行target参数中的内容(方法/函数)。

1
self.data_lock = threading.Lock()

threading.Lock()来实例化一个锁对象,并不需要指定锁的对象:

1
2
3
4
5
6
7
lock = threading.Lock()

lock.acquire()
try:
do_sth
finally:
lock.release()

将lock.acquire()和lock.release()之间的内容锁住。 当多个线程同时执行lock.acquire()时,只有一个线程能成功地获取锁,然后继续执行代码,其他线程就继续等待直到获得锁为止。

threading.RLock() 允许在同一线程中被多次acquire。而Lock却不允许这种情况。注意:如果使用RLock,那么acquire和release必须成对出现,即调用了n次acquire,必须调用n次的release才能真正释放所占用的琐。

如果多个线程共同对某个数据修改,则可能出现不可预料的结果,为了保证数据的正确性,需要对多个线程进行同步。

使用 Thread 对象的 Lock 和 Rlock 可以实现简单的线程同步,这两个对象都有 acquire 方法和 release 方法,对于那些需要每次只允许一个线程操作的数据,可以将其操作放到 acquire 和 release 方法之间。

traceback

异常模块,小轮子:

1
2
3
4
5
6
7
8
9
10
11
import traceback

try:
sth
except Exception as ex:
import traceback
trace = traceback.format_exc()
print(trace)
errorFile = open('error.txt', 'a+')
errorFile.write('exception in lookup_func %s\n' % str(ex))
errorFile.write(str(trace))

pickle

https://yq.aliyun.com/articles/414599

我们把变量从内存中变成可存储或传输的过程称之为序列化,在Python中叫pickling,在其他语言中也被称之为serialization,marshalling,flattening等等,都是一个意思。

序列化之后,就可以把序列化后的内容写入磁盘,或者通过网络传输到别的机器上。

反过来,把变量内容从序列化的对象重新读到内存里称之为反序列化,即unpickling。

Python提供了pickle模块来实现序列化。

首先,我们尝试把一个对象序列化并写入文件:

1
2
3
4
5
> >>> import pickle
> >>> d = dict(name='Bob', age=20, score=88)
> >>> pickle.dumps(d)
> b'\x80\x03}q\x00(X\x03\x00\x00\x00ageq\x01K\x14X\x05\x00\x00\x00scoreq\x02KXX\x04\x00\x00\x00nameq\x03X\x03\x00\x00\x00Bobq\x04u.'
>

pickle.dumps()方法把任意对象序列化成一个bytes,然后,就可以把这个bytes写入文件。或者用另一个方法pickle.dump()直接把对象序列化后写入一个file-like Object:

1
2
3
4
> >>> f = open('dump.txt', 'wb')
> >>> pickle.dump(d, f)
> >>> f.close()
>

看看写入的dump.txt文件,一堆乱七八糟的内容,这些都是Python保存的对象内部信息。

当我们要把对象从磁盘读到内存时,可以先把内容读到一个bytes,然后用pickle.loads()方法反序列化出对象,也可以直接用pickle.load()方法从一个file-like Object中直接反序列化出对象。

Pickle的问题和所有其他编程语言特有的序列化问题一样,就是它只能用于Python,并且可能不同版本的Python彼此都不兼容,因此,只能用Pickle保存那些不重要的数据,不能成功地反序列化也没关系。

struct

1
messageLength = struct.pack('>I', len(keyPickle))

struct.pack()将数据转换为Python的字符串类型,便于以流的方式操作,参数用来结构化字符串。>表示big-endian(大端存储方式,高位放在高地址段),I表示unsigned int类型。

这行代码就是用来制作一个表示接下来发送的信息长度的流,用来告知客户端。

random

random.sample(population, k)

Return a k length list of unique elements chosen from the population sequence or set.

1
2
3
4
5
6
7
import random

a = [[1, 2], [3, 4], [5, 6]]

b = random.sample(a, 1)

print(b)

某次输出:

1
[[1, 2]]

注意,需要用b[0]访问到想要的结果。

总结

batch_cache.py文件运行后,在客户端链接之前,完成的工作有:

  • 启动新线程 (self.worker) 执行__memory_monitor(self)方法(当客户端链接之后才开始起作用)

  • 载入数据集全部图片信息,选取关键信息存放在变量self.all_keys、self.image_paths中(通过create_keys()和add_dataset()函数完成)。注意,这里的self.all_keys中的帧信息是筛选过的,确保在num_unrolls之后的那一帧还在当前的视频序列,也就是除去每个视频的后num_unrolls帧。

    • self.all_keys: set, element’s format is [num (0, 1, 2, …), vv, trackId, imNum]
    • self.image_paths: list, element’s format is str
  • 从数据集 (self.all_keys) 中随机选择至少32组数据添加到self.vals、self.keys中(可以把这两个变量当做缓冲区)(由__random_load()函数完成)

    • self.keys: element’s format is like self.all_keys

    • self.vals: element’s format is a list (包含选择的这帧及往后的num_unrolls个相邻帧(set),每个set中是图片内容和形状字节形式)

      由变量决定随机选择多少组,每组有num_unrolls个帧的信息,只有第一帧key信息。

  • 创建TCP服务器,等待客户端的链接

在客户端链接之后:

  • get_sample()函数随机从缓冲区中选择一组,进行发送,并标记这组数据(self.data_hits加1)
  • 一旦开始发送数据给客户端,__memory_monitor(self)方法就往缓冲区(self.vals、self.keys)添加随机数据,最大存放100组(内存占用考虑)
  • __random_load()函数在客户端链接后负责更新缓冲区(将最大self.data_hits对应的数据替换成下一个随机数据)
  • 客户端连接后,除了主线程,还有两个线程在运行:一个负责__random_load()函数——更新缓冲区,一个负责get_sample()函数——随机选择数据发送。主线程等待两个线程的结束。

问题

新加入缓冲区内的数据(随机选取)有可能之前发送过

unrolled_solver.py (☆key script as client)

代码中的知识点

np.set_printoptions( )

文档

可以控制Numpy数据的输出格式。

tf.logging.set_verbosity(tf.logging.INFO)

tensorflow使用五级日志:DEBUG, INFO, WARN, ERROR, and FATAL,默认为WARN

把日志设置为INFO级别:

1
tf.logging.set_verbosity(tf.logging.INFO)

会打印配置信息,创建CheckpointSaverHook,并且把过程打出来。每百步的损失值,消耗时间,每秒跑的步数等:

1
2
3
4
5
6
7
> INFO:tensorflow:step = 40001, loss = 4.66333
> INFO:tensorflow:global_step/sec: 458.689
> INFO:tensorflow:step = 40101, loss = 1.37071 (0.219 sec)
> INFO:tensorflow:global_step/sec: 602.375
> INFO:tensorflow:step = 40201, loss = 2.09407 (0.164 sec)
> INFO:tensorflow:global_step/sec: 781.205
>

tf_dataset.py(☆☆key script used in unrolled_solver.py to generate standard data)

get_dataset()

http://geyao1995.com/TensorFlow_learning/

get_data_sequence()

主要获取两种 data sequence:1. ILSVRC VID的真实视频数据。2.ILSVRC DET生成的模拟视频数据。主要由useSimulator这个变量控制。

获取真实数据

获取由batch_cache.py进程传递过来的数据,得到的数据是数据集中随机的连续time_steps帧(同一个视频序列的)。将每一帧与前一帧放到一个维度上,令最后的输出维度是(time_steps×2, 227, 227, 3),符合输入。为了泛华,随机对bbox加上干扰(self.add_noise),由realMotion变量控制。

注意 Ⅲ. METHOD 部分 Part B Learning to Fix Mistake:

However, if we always provide the network with ground-truth crops, at test time it quickly accumulates more drift than it has ever encountered, and loses track of the object. To counteract this, we employ a regime that initially relies on ground-truth crops, but over time the network uses its own predictions to
generate the next crops. We initially only use the ground truth crops, and as we double the number of unrolls, we increase the probability of using predicted crops to first 0.25, then subsequently 0.5 and 0.75.

if gtType < USE_NETWORK_PROB:中的代码就在完成这件事,用预测得到的下一帧bbox代替下一帧的gt bbox。但是USE_NETWORK_PROB这个变量在代码中并没有修改,可能得手动调。

###生成模拟数据

运行报错

错误一:需要添加path

1
Couldn't open CUDA library libcupti.so.9.0. LD_LIBRARY_PATH: /usr/local/cuda/lib64

一般是/usr/local/cuda-8.0/lib64/usr/local/cuda-9.0/lib64,这里创建了软链接cuda,所以就是/usr/local/cuda/lib64,在该目录下寻找libcupti.so.9.0.文件,该文件位置在:/usr/local/cuda/extras/CUPTI/lib64,将这个路径添加到环境变量中就行了。

PyCharm -> Run Editconfigurations-> Environment Variables:

Name设置LD_LIBRARY_PATH, Value设置/usr/local/cuda/extras/CUPTI/lib64, 如果有多个值,中间用:隔开。

错误二:端口被占用

当程序未正常关闭,可能产生这个错误,从终端kill占用端口的无用进程即可。

1
sudo lsof -i:9997

显示:

1
2
COMMAND   PID   USER   FD   TYPE  DEVICE SIZE/OFF NODE NAME
python 32649 ubuntu 4u IPv4 1163650 0t0 TCP localhost:9997 (LISTEN)

利用PID杀掉进程:

1
sudo kill 32649

----------over----------


文章标题:Re3网络的训练过程

文章作者:Ge垚

发布时间:2018年08月02日 - 13:08

最后更新:2019年01月30日 - 11:01

原始链接:http://geyao1995.com/Re3_train/

许可协议: 署名-非商业性使用-禁止演绎 4.0 国际 转载请保留原文链接及作者。