# Copyright 2018 Tensorforce Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from collections import OrderedDict
import importlib
import json
import logging
import os
import random
import time
import numpy as np
import tensorflow as tf
from tensorforce import util, TensorforceError
import tensorforce.agents
class Agent(object):
    """
    Tensorforce agent interface.

    Base class defining the public agent API: creation from a specification
    (`Agent.create`), the act/observe interaction loop, and persistence via
    `save`/`restore`. Subclasses are expected to set up a `model` attribute
    before `initialize` is called.
    """
[docs] @staticmethod
def create(agent=None, environment=None, **kwargs):
"""
Creates an agent from a specification.
Args:
agent (specification): JSON file, specification key, configuration dictionary,
library module, or `Agent` subclass
(<span style="color:#00C000"><b>default</b></span>: Policy agent).
environment (Environment): Environment which the agent is supposed to be trained on,
environment-related arguments like state/action space specifications will be
extract if given.
kwargs: Additional arguments.
"""
if agent is None:
agent = 'default'
if isinstance(agent, Agent):
# TODO: asserts???????
return agent
elif isinstance(agent, dict):
# Dictionary specification
util.deep_disjoint_update(target=kwargs, source=agent)
agent = kwargs.pop('agent', kwargs.pop('type', 'default'))
return Agent.create(agent=agent, environment=environment, **kwargs)
elif isinstance(agent, str):
if os.path.isfile(agent):
# JSON file specification
with open(agent, 'r') as fp:
agent = json.load(fp=fp)
util.deep_disjoint_update(target=kwargs, source=agent)
agent = kwargs.pop('agent', kwargs.pop('type', 'default'))
return Agent.create(agent=agent, environment=environment, **kwargs)
elif '.' in agent:
# Library specification
library_name, module_name = agent.rsplit('.', 1)
library = importlib.import_module(name=library_name)
agent = getattr(library, module_name)
if environment is not None:
env_spec = dict(states=environment.states(), actions=environment.actions())
if environment.max_episode_timesteps() is not None:
env_spec['max_episode_timesteps'] = environment.max_episode_timesteps()
util.deep_disjoint_update(target=kwargs, source=env_spec)
agent = agent(**kwargs)
assert isinstance(agent, Agent)
return agent
else:
# Keyword specification
if environment is not None:
env_spec = dict(states=environment.states(), actions=environment.actions())
if environment.max_episode_timesteps() is not None:
env_spec['max_episode_timesteps'] = environment.max_episode_timesteps()
util.deep_disjoint_update(target=kwargs, source=env_spec)
agent = tensorforce.agents.agents[agent](**kwargs)
assert isinstance(agent, Agent)
return agent
else:
assert False
def __init__(
# Environment
self, states, actions, max_episode_timesteps=None,
# TensorFlow etc
parallel_interactions=1, buffer_observe=True, seed=None, recorder=None
):
if seed is not None:
assert isinstance(seed, int)
random.seed(n=seed)
np.random.seed(seed=seed)
tf.random.set_random_seed(seed=seed)
# States/actions specification
self.states_spec = util.valid_values_spec(
values_spec=states, value_type='state', return_normalized=True
)
self.actions_spec = util.valid_values_spec(
values_spec=actions, value_type='action', return_normalized=True
)
self.max_episode_timesteps = max_episode_timesteps
# Check for name overlap
for name in self.states_spec:
if name in self.actions_spec:
TensorforceError.collision(
name='name', value=name, group1='states', group2='actions'
)
# Parallel episodes
if isinstance(parallel_interactions, int):
if parallel_interactions <= 0:
raise TensorforceError.value(
name='parallel_interactions', value=parallel_interactions
)
self.parallel_interactions = parallel_interactions
else:
raise TensorforceError.type(name='parallel_interactions', value=parallel_interactions)
# Buffer observe
if isinstance(buffer_observe, bool):
if not buffer_observe and self.parallel_interactions > 1:
raise TensorforceError.unexpected()
if self.max_episode_timesteps is None and self.parallel_interactions > 1:
raise TensorforceError.unexpected()
if not buffer_observe:
self.buffer_observe = 1
elif self.max_episode_timesteps is None:
self.buffer_observe = 100
else:
self.buffer_observe = self.max_episode_timesteps
elif isinstance(buffer_observe, int):
if buffer_observe <= 0:
raise TensorforceError.value(name='buffer_observe', value=buffer_observe)
if self.parallel_interactions > 1:
raise TensorforceError.unexpected()
if self.max_episode_timesteps is None:
self.buffer_observe = buffer_observe
else:
self.buffer_observe = min(buffer_observe, self.max_episode_timesteps)
else:
raise TensorforceError.type(name='buffer_observe', value=buffer_observe)
# Parallel terminal/reward buffers
self.terminal_buffers = np.ndarray(
shape=(self.parallel_interactions, self.buffer_observe),
dtype=util.np_dtype(dtype='long')
)
self.reward_buffers = np.ndarray(
shape=(self.parallel_interactions, self.buffer_observe),
dtype=util.np_dtype(dtype='float')
)
# Parallel buffer indices
self.buffer_indices = np.zeros(
shape=(self.parallel_interactions,), dtype=util.np_dtype(dtype='int')
)
self.timestep = 0
self.episode = 0
# Recorder
if recorder is None:
pass
elif not all(key in ('directory', 'frequency', 'max-traces') for key in recorder):
raise TensorforceError.value(name='recorder', value=list(recorder))
self.recorder_spec = recorder
if self.recorder_spec is not None:
self.record_states = OrderedDict(((name, list()) for name in self.states_spec))
for name, spec in self.actions_spec.items():
if spec['type'] == 'int':
self.record_states[name + '_mask'] = list()
self.record_actions = OrderedDict(((name, list()) for name in self.actions_spec))
self.record_terminal = list()
self.record_reward = list()
self.num_episodes = 0
def __str__(self):
return self.__class__.__name__
    def initialize(self):
        """
        Initializes the agent.

        Raises:
            TensorforceError: If the subclass has not set up a `model` attribute.
        """
        # Subclasses must assign self.model before initialization.
        if not hasattr(self, 'model'):
            raise TensorforceError.missing(name='Agent', value='model')

        # Setup Model
        # (create and build graph (local and global if distributed), server, session, etc..).
        self.model.initialize()
        # Start from a clean buffer/counter state.
        self.reset()
[docs] def close(self):
"""
Closes the agent.
"""
self.model.close()
    def reset(self):
        """
        Resets the agent to start a new episode.
        """
        # Discard any partially-buffered terminal/reward observations.
        self.buffer_indices = np.zeros(
            shape=(self.parallel_interactions,), dtype=util.np_dtype(dtype='int')
        )
        # The model returns the persisted global timestep/episode counters.
        self.timestep, self.episode = self.model.reset()
[docs] def act(
self, states, parallel=0, deterministic=False, independent=False, evaluation=False,
query=None, **kwargs
):
"""
Returns action(s) for the given state(s), needs to be followed by `observe(...)` unless
`independent` is true.
Args:
states (dict[state]): Dictionary containing state(s) to be acted on
(<span style="color:#C00000"><b>required</b></span>).
parallel (int): Parallel execution index
(<span style="color:#00C000"><b>default</b></span>: 0).
deterministic (bool): Whether to apply exploration and sampling
(<span style="color:#00C000"><b>default</b></span>: false).
independent (bool): Whether action is not remembered, and this call is thus not
followed by observe
(<span style="color:#00C000"><b>default</b></span>: false).
evaluation (bool): Whether the agent is currently evaluated, implies and overwrites
deterministic and independent
(<span style="color:#00C000"><b>default</b></span>: false).
query (list[str]): Names of tensors to retrieve
(<span style="color:#00C000"><b>default</b></span>: none).
kwargs: Additional input values, for instance, for dynamic hyperparameters.
Returns:
(dict[action], plus optional list[str]): Dictionary containing action(s), plus queried
tensor values if requested.
"""
assert util.reduce_all(predicate=util.not_nan_inf, xs=states)
# self.current_internals = self.next_internals
if evaluation:
if deterministic or independent:
raise TensorforceError.unexpected()
deterministic = independent = True
# Auxiliaries
auxiliaries = OrderedDict()
if isinstance(states, dict):
states = dict(states)
for name, spec in self.actions_spec.items():
if spec['type'] == 'int' and name + '_mask' in states:
auxiliaries[name + '_mask'] = states.pop(name + '_mask')
# Normalize states dictionary
states = util.normalize_values(
value_type='state', values=states, values_spec=self.states_spec
)
# Batch states
states = util.fmap(function=(lambda x: np.asarray([x])), xs=states, depth=1)
auxiliaries = util.fmap(function=(lambda x: np.asarray([x])), xs=auxiliaries, depth=1)
# Model.act()
if query is None:
actions, self.timestep = self.model.act(
states=states, auxiliaries=auxiliaries, parallel=parallel,
deterministic=deterministic, independent=independent, **kwargs
)
else:
actions, self.timestep, queried = self.model.act(
states=states, auxiliaries=auxiliaries, parallel=parallel,
deterministic=deterministic, independent=independent, query=query, **kwargs
)
if self.recorder_spec is not None and not independent:
for name in self.states_spec:
self.record_states[name].append(states[name])
for name, spec in self.actions_spec.items():
self.record_actions[name].append(actions[name])
if spec['type'] == 'int':
if name + '_mask' in auxiliaries:
self.record_states[name].append(auxiliaries[name + '_mask'])
else:
shape = (1,) + spec['shape'] + (spec['num_values'],)
self.record_states[name].append(
np.full(shape, True, dtype=util.np_dtype(dtype='bool'))
)
# Unbatch actions
actions = util.fmap(function=(lambda x: x[0]), xs=actions, depth=1)
# Reverse normalized actions dictionary
actions = util.unpack_values(
value_type='action', values=actions, values_spec=self.actions_spec
)
# if independent, return processed state as well?
if query is None:
return actions
else:
return actions, queried
    def observe(self, reward, terminal=False, parallel=0, query=None, **kwargs):
        """
        Observes reward and whether a terminal state is reached, needs to be preceded by
        `act(...)`.

        Args:
            reward (float): Reward
                (<span style="color:#C00000"><b>required</b></span>).
            terminal (bool | 0 | 1 | 2): Whether a terminal state is reached or 2 if the
                episode was aborted (<span style="color:#00C000"><b>default</b></span>: false).
            parallel (int): Parallel execution index
                (<span style="color:#00C000"><b>default</b></span>: 0).
            query (list[str]): Names of tensors to retrieve
                (<span style="color:#00C000"><b>default</b></span>: none).
            kwargs: Additional input values, for instance, for dynamic hyperparameters.

        Returns:
            (bool, optional list[str]): Whether an update was performed, plus queried tensor
            values if requested.
        """
        assert util.reduce_all(predicate=util.not_nan_inf, xs=reward)

        # Querying forces an immediate model observe, which is only supported for a
        # single parallel interaction.
        if query is not None and self.parallel_interactions > 1:
            raise TensorforceError.unexpected()

        # Normalize bool terminal to the int encoding (0/1/2).
        if isinstance(terminal, bool):
            terminal = int(terminal)

        if self.recorder_spec is not None:
            self.record_terminal.append(terminal)
            self.record_reward.append(reward)
            if terminal > 0:
                self.num_episodes += 1
                # Flush recorded traces to disk every `frequency` completed episodes.
                if self.num_episodes == self.recorder_spec.get('frequency', 1):
                    directory = self.recorder_spec['directory']
                    if os.path.isdir(directory):
                        # Existing trace files, sorted by name (timestamped filenames).
                        files = sorted(
                            f for f in os.listdir(directory)
                            if os.path.isfile(os.path.join(directory, f))
                            and f.startswith('trace-')
                        )
                    else:
                        os.makedirs(directory)
                        files = list()
                    # Enforce the max-traces limit by removing the oldest trace files,
                    # keeping room for the one about to be written.
                    max_traces = self.recorder_spec.get('max-traces')
                    if max_traces is not None and len(files) > max_traces - 1:
                        for filename in files[:-max_traces + 1]:
                            filename = os.path.join(directory, filename)
                            os.remove(filename)
                    filename = 'trace-{}-{}.npz'.format(
                        self.episode, time.strftime('%Y%m%d-%H%M%S')
                    )
                    filename = os.path.join(directory, filename)
                    # Concatenate the per-timestep records into single arrays.
                    self.record_states = util.fmap(
                        function=np.concatenate, xs=self.record_states, depth=1
                    )
                    self.record_actions = util.fmap(
                        function=np.concatenate, xs=self.record_actions, depth=1
                    )
                    self.record_terminal = np.asarray(self.record_terminal)
                    self.record_reward = np.asarray(self.record_reward)
                    np.savez_compressed(
                        filename, **self.record_states, **self.record_actions,
                        terminal=self.record_terminal, reward=self.record_reward
                    )
                    # Reset record buffers for the next batch of episodes.
                    self.record_states = util.fmap(
                        function=(lambda x: list()), xs=self.record_states, depth=1
                    )
                    self.record_actions = util.fmap(
                        function=(lambda x: list()), xs=self.record_actions, depth=1
                    )
                    self.record_terminal = list()
                    self.record_reward = list()
                    self.num_episodes = 0

        # Update terminal/reward buffer for this parallel slot
        index = self.buffer_indices[parallel]
        self.terminal_buffers[parallel, index] = terminal
        self.reward_buffers[parallel, index] = reward
        index += 1

        # Guard against episodes exceeding the declared horizon.
        if self.max_episode_timesteps is not None and index > self.max_episode_timesteps:
            raise TensorforceError.unexpected()

        # Flush the buffer to the model on episode end, when full, or when querying.
        if terminal > 0 or index == self.buffer_observe or query is not None:
            # Model.observe()
            if query is None:
                updated, self.episode = self.model.observe(
                    terminal=self.terminal_buffers[parallel, :index],
                    reward=self.reward_buffers[parallel, :index], parallel=parallel, **kwargs
                )
            else:
                updated, self.episode, queried = self.model.observe(
                    terminal=self.terminal_buffers[parallel, :index],
                    reward=self.reward_buffers[parallel, :index], parallel=parallel, query=query,
                    **kwargs
                )
            # Reset buffer index
            self.buffer_indices[parallel] = 0
        else:
            # Increment buffer index
            self.buffer_indices[parallel] = index
            updated = False

        if query is None:
            return updated
        else:
            return updated, queried
[docs] def save(self, directory=None, filename=None, append_timestep=True):
"""
Saves the current state of the agent.
Args:
directory (str): Checkpoint directory
(<span style="color:#00C000"><b>default</b></span>: directory specified for
TensorFlow saver).
filename (str): Checkpoint filename
(<span style="color:#00C000"><b>default</b></span>: filename specified for
TensorFlow saver).
append_timestep: Whether to append the current timestep to the checkpoint file
(<span style="color:#00C000"><b>default</b></span>: true).
Returns:
str: Checkpoint path.
"""
# TODO: Messes with required parallel disentangling, better to remove unfinished episodes
# from memory, but currently entire episode buffered anyway...
# # Empty buffers before saving
# for parallel in range(self.parallel_interactions):
# index = self.buffer_indices[parallel]
# if index > 0:
# # if self.parallel_interactions > 1:
# # raise TensorforceError.unexpected()
# self.episode = self.model.observe(
# terminal=self.terminal_buffers[parallel, :index],
# reward=self.reward_buffers[parallel, :index], parallel=parallel
# )
# self.buffer_indices[parallel] = 0
return self.model.save(
directory=directory, filename=filename, append_timestep=append_timestep
)
[docs] def restore(self, directory=None, filename=None):
"""
Restores the agent.
Args:
directory (str): Checkpoint directory
(<span style="color:#00C000"><b>default</b></span>: directory specified for
TensorFlow saver).
filename (str): Checkpoint filename
(<span style="color:#00C000"><b>default</b></span>: latest checkpoint in
directory).
"""
if not hasattr(self, 'model'):
raise TensorforceError.missing(name='Agent', value='model')
if not self.model.is_initialized:
self.model.initialize()
self.timestep, self.episode = self.model.restore(directory=directory, filename=filename)
[docs] def get_output_tensors(self, function):
"""
Returns the names of output tensors for the given function.
Args:
function (str): Function name
(<span style="color:#C00000"><b>required</b></span>).
Returns:
list[str]: Names of output tensors.
"""
if function in self.model.output_tensors:
return self.model.output_tensors[function]
else:
raise TensorforceError.unexpected()
[docs] def get_query_tensors(self, function):
"""
Returns the names of queryable tensors for the given function.
Args:
function (str): Function name
(<span style="color:#C00000"><b>required</b></span>).
Returns:
list[str]: Names of queryable tensors.
"""
if function in self.model.query_tensors:
return self.model.query_tensors[function]
else:
raise TensorforceError.unexpected()
[docs] def get_available_summaries(self):
"""
Returns the summary labels provided by the agent.
Returns:
list[str]: Available summary labels.
"""
return self.model.get_available_summaries()
    def should_stop(self):
        # Whether the model's monitored session has been requested to stop.
        return self.model.monitored_session.should_stop()