|
- import matplotlib.pyplot as plt
- import numpy as np
- import seaborn as sns
- import os
-
# Global plot styling: seaborn's "ticks" theme in the notebook context,
# scaled up 1.5x, with thicker lines and slightly larger markers.
_LINE_RC = {"lines.linewidth": 2.5, "lines.markersize": 5.5}
sns.set(style='ticks')
sns.set_context("notebook", font_scale=1.5, rc=_LINE_RC)

# Bin width that divide_group uses when bucketing frame counts.
DIGIT = 5000
-
def iter_dirs(dir, env_id):
    """Return the paths of every entry inside run directories matching *env_id*.

    Each immediate child of *dir* whose name contains *env_id* as a substring
    and that is itself a directory is treated as a run directory. Every entry
    inside such a run directory contributes one path to the result.

    Args:
        dir: Root results directory (e.g. ``"./results/A2C/"``). A trailing
            separator is optional. The name shadows the ``dir`` builtin but is
            kept so existing callers are unaffected.
        env_id: Substring identifying the task, e.g. ``"HalfCheetah-v3"``.

    Returns:
        A list of path strings, sorted for a deterministic order. Entries are
        not filtered by type; callers select the files they need.
    """
    candidate_dirs = []
    for subdir in sorted(os.listdir(dir)):
        run_dir = os.path.join(dir, subdir)
        # A plain file whose name happens to contain env_id is not a run
        # directory; listing it would raise NotADirectoryError.
        if env_id not in subdir or not os.path.isdir(run_dir):
            continue
        candidate_dirs.extend(
            os.path.join(run_dir, filename)
            for filename in sorted(os.listdir(run_dir))
        )
    return candidate_dirs
-
def smooth(ys, coef=0.98):
    """Exponentially smooth a sequence of values.

    Each output is ``coef * previous + (1 - coef) * y``, with the running
    value seeded from the first element. A larger *coef* therefore gives a
    smoother, more lagged curve.

    Args:
        ys: Indexable sequence of numbers (list, tuple, or 1-D array).
        coef: Weight on the running value; expected in [0, 1].

    Returns:
        A list the same length as *ys*. Empty input yields ``[]``.
    """
    # len() rather than truthiness so numpy arrays work too.
    if len(ys) == 0:
        return []
    result = []
    last = ys[0]
    for y in ys:
        last = coef * last + (1 - coef) * y
        result.append(last)
    return result
-
def divide_group(data, coef=0.999):
    """Split episode records by env, smooth each env's scores, and flatten.

    Args:
        data: Iterable of episode records. Each record is a mapping with the
            keys ``'env'`` (integer env id), ``'clocktime'``, ``'frame'`` and
            ``'score'``. Presumably the object array loaded via
            ``np.load(..., allow_pickle=True)`` — TODO confirm the schema.
        coef: Smoothing coefficient forwarded to ``smooth``.

    Returns:
        ``(times, frames, scores)``: three equal-length flat lists, ordered by
        env id and, within an env, by record order. Clock times are floored to
        multiples of 20, frames are bucketed to multiples of ``DIGIT``, and
        scores are exponentially smoothed separately for each env. Empty
        *data* yields three empty lists.
    """
    # Group by the actual env id. A dict avoids the previous fixed-size lists
    # indexed by raw env id, which overflowed when ids did not start at 0, and
    # never produces an empty group (which would crash smooth).
    groups = {}
    for episode_data in data:
        times, frames, scores = groups.setdefault(episode_data['env'], ([], [], []))
        times.append(int(episode_data['clocktime'] // 20 * 20.0))
        # Shift by a quarter bin, then floor to a multiple of DIGIT.
        frames.append(int(episode_data['frame'] + DIGIT // 4) // DIGIT * DIGIT)
        scores.append(episode_data['score'])

    result_time, result_frame, result_score = [], [], []
    for env in sorted(groups):
        times, frames, scores = groups[env]
        result_time.extend(times)
        result_frame.extend(frames)
        result_score.extend(smooth(scores, coef))
    return result_time, result_frame, result_score
# Result sub-directories (under PREFIX) to compare on one figure.
# Other variant sets that can be swapped in:
#   ['vpg', 'vpg-obsnorm', 'vpg-rewnorm', 'vpg-advnorm']
#   ['a2c', 'a2c-gradclip', 'a2c-gae', 'a2c-obsnorm', 'a2c-rewnorm', 'A2C']
#   ['a2c', 'a2c-gae', 'a2c-gradclip', 'a2c-obsnorm', 'a2c-rewnorm', 'a2c-advnorm']
ALGOS = ['A2C']
TASK = "HalfCheetah-v3"
PREFIX = "./results/"


def main():
    """Plot smoothed score curves for every algorithm in ALGOS on TASK.

    For each algorithm, every ``.npy`` file in run directories under
    ``PREFIX + algo + "/"`` whose name contains TASK is loaded. Each file is
    passed through ``divide_group`` (coef 0.90), and the results are pooled.
    The figure has two panels: score vs. seconds (left) and score vs. frames
    (right).
    """
    # Create both axes once and draw into them explicitly, instead of
    # re-fetching them through repeated plt.subplot() calls in the loop.
    _, (ax_time, ax_frame) = plt.subplots(1, 2, figsize=(12.0, 6.0))
    for algo in ALGOS:
        x_times, x_frames, y_scores = [], [], []
        for path in iter_dirs(PREFIX + algo + "/", TASK):
            if not path.endswith(".npy"):
                continue
            # NOTE(review): allow_pickle unpickles arbitrary objects; only
            # load result files you produced yourself.
            times, frames, scores = divide_group(
                np.load(path, allow_pickle=True), coef=0.90)
            x_times.extend(times)
            x_frames.extend(frames)
            y_scores.extend(scores)
        if not y_scores:
            print("no .npy results for %s on %s under %s" % (algo, TASK, PREFIX))
            continue
        sns.lineplot(x=x_times, y=y_scores, legend='brief', label=algo, ax=ax_time)
        sns.lineplot(x=x_frames, y=y_scores, legend='brief', label=algo, ax=ax_frame)

    ax_time.set_title("%s-score-time" % TASK)
    ax_time.set_xlabel("seconds")
    ax_time.grid(linestyle="-.")

    ax_frame.set_title("%s-score-frame" % TASK)
    ax_frame.set_xlabel("frames")
    ax_frame.grid(linestyle="-.")
    ax_frame.ticklabel_format(style="sci", scilimits=(0, 1), axis='x')

    plt.show()


if __name__ == "__main__":
    main()
-
-
-
-
-
-
-
|