# Copyright 2018 Tensorforce Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import importlib
import json
import os
from threading import Thread
from tensorforce import TensorforceError, util
import tensorforce.environments
class Environment(object):
    """
    Tensorforce environment interface.

    Subclasses are expected to override `states()`, `actions()`, `reset()` and `execute()`,
    which raise `NotImplementedError` in this base class. The base class additionally offers
    thread-based helpers (`start_reset()`, `start_execute()`, `retrieve_execute()`) which run
    `reset()` / `execute()` in a background thread and hand the result over via
    `self.observation`.
    """

    @staticmethod
    def create(environment, **kwargs):
        """
        Creates an environment from a specification.

        Args:
            environment (specification): JSON file, specification key, configuration dictionary,
                library module, or an existing `Environment` instance, which is returned
                unchanged (<span style="color:#C00000"><b>required</b></span>).
            kwargs: Additional arguments.

        Returns:
            Environment: The created (or passed-through) environment object.

        Raises:
            AssertionError: If the specification has an unsupported type, or does not resolve
                to an `Environment` instance.
        """
        if isinstance(environment, Environment):
            # TODO: asserts???????
            return environment

        elif isinstance(environment, dict):
            # Dictionary specification: merge it into kwargs, then resolve the actual
            # environment spec from the 'environment' (or 'type') entry and recurse.
            util.deep_disjoint_update(target=kwargs, source=environment)
            # Note: the inner pop is always evaluated, so 'type' is removed from kwargs
            # even if 'environment' is present.
            environment = kwargs.pop('environment', kwargs.pop('type', 'default'))
            assert environment is not None
            return Environment.create(environment=environment, **kwargs)

        elif isinstance(environment, str):
            if os.path.isfile(environment):
                # JSON file specification: load it and treat the content like a dictionary
                # specification.
                with open(environment, 'r') as fp:
                    environment = json.load(fp=fp)
                util.deep_disjoint_update(target=kwargs, source=environment)
                environment = kwargs.pop('environment', kwargs.pop('type', 'default'))
                assert environment is not None
                return Environment.create(environment=environment, **kwargs)

            elif '.' in environment:
                # Library specification: "<module path>.<attribute name>"
                library_name, module_name = environment.rsplit('.', 1)
                library = importlib.import_module(name=library_name)
                environment = getattr(library, module_name)
                environment = environment(**kwargs)
                assert isinstance(environment, Environment)
                return environment

            else:
                # Keyword specification: look up the registered environment by key
                environment = tensorforce.environments.environments[environment](**kwargs)
                assert isinstance(environment, Environment)
                return environment

        else:
            # Raised explicitly instead of via `assert False`, so the check is not stripped
            # when Python runs with -O (which would silently return None). AssertionError is
            # kept so callers observe the same exception type as before.
            raise AssertionError(
                "Invalid environment specification: {!r}".format(environment)
            )

    def __init__(self):
        # first two arguments, if applicable: level, visualize=False
        # Result of the most recent background reset/execute, until it is retrieved.
        self.observation = None
        # Background thread currently running reset/execute, or None if idle.
        self.thread = None

    def __str__(self):
        """Returns the name of the concrete environment class."""
        return self.__class__.__name__

    def states(self):
        """
        Returns the state space specification.

        Returns:
            specification: Arbitrarily nested dictionary of state descriptions with the following
            attributes:
            <ul>
            <li><b>type</b> (<i>"bool" | "int" | "float"</i>) – state data type
            (<span style="color:#00C000"><b>default</b></span>: "float").</li>
            <li><b>shape</b> (<i>int | iter[int]</i>) – state shape
            (<span style="color:#C00000"><b>required</b></span>).</li>
            <li><b>num_states</b> (<i>int > 0</i>) – number of discrete state values
            (<span style="color:#C00000"><b>required</b></span> for type "int").</li>
            <li><b>min_value/max_value</b> (<i>float</i>) – minimum/maximum state value
            (<span style="color:#00C000"><b>optional</b></span> for type "float").</li>
            </ul>
        """
        raise NotImplementedError

    def actions(self):
        """
        Returns the action space specification.

        Returns:
            specification: Arbitrarily nested dictionary of action descriptions with the following
            attributes:
            <ul>
            <li><b>type</b> (<i>"bool" | "int" | "float"</i>) – action data type
            (<span style="color:#C00000"><b>required</b></span>).</li>
            <li><b>shape</b> (<i>int > 0 | iter[int > 0]</i>) – action shape
            (<span style="color:#00C000"><b>default</b></span>: scalar).</li>
            <li><b>num_actions</b> (<i>int > 0</i>) – number of discrete action values
            (<span style="color:#C00000"><b>required</b></span> for type "int").</li>
            <li><b>min_value/max_value</b> (<i>float</i>) – minimum/maximum action value
            (<span style="color:#00C000"><b>optional</b></span> for type "float").</li>
            </ul>
        """
        raise NotImplementedError

    def max_episode_timesteps(self):
        """
        Returns the maximum number of timesteps per episode.

        Returns:
            int: Maximum number of timesteps per episode. This base implementation returns
            None.
        """
        return None

    def close(self):
        """
        Closes the environment.

        Waits for a background reset/execute thread, if one is registered, and clears the
        stored observation and thread reference.
        """
        # Read the attribute exactly once: the worker thread sets `self.thread = None` when it
        # finishes, so checking and then re-reading `self.thread` could end up calling
        # `.join()` on None.
        thread = self.thread
        if thread is not None:
            thread.join()
        self.observation = None
        self.thread = None

    def reset(self):
        """
        Resets the environment to start a new episode.

        Returns:
            dict[state]: Dictionary containing initial state(s) and auxiliary information.
        """
        raise NotImplementedError
        # Disabled draft of a default implementation on top of start_reset (it refers to an
        # `observe` method which is not defined in this class):
        # if self.observation is not None or self.thread is not None:
        #     raise TensorforceError(message="Invalid execute.")
        # self.start_reset()
        # self.thread.join()
        # states, _, _ = self.observe()
        # if self.observation is not None:
        #     raise TensorforceError(message="Invalid start_reset/observe implementation.")
        # return states

    def execute(self, actions):
        """
        Executes the given action(s) and advances the environment by one step.

        Args:
            actions (dict[action]): Dictionary containing action(s) to be executed
                (<span style="color:#C00000"><b>required</b></span>).

        Returns:
            ((dict[state], bool | 0 | 1 | 2, float)): Dictionary containing next state(s), whether
            a terminal state is reached or 2 if the episode was aborted, and observed reward.
        """
        raise NotImplementedError
        # Disabled draft of a default implementation on top of start_execute (it refers to an
        # `observe` method which is not defined in this class):
        # if self.observation is not None or self.thread is not None:
        #     raise TensorforceError(message="Invalid execute.")
        # self.start_execute(actions=actions)
        # self.thread.join()
        # observation = self.observe()
        # if self.observation is not None:
        #     raise TensorforceError(message="Invalid start_execute/observe implementation.")
        # return observation

    def start_reset(self):
        """
        Starts `reset()` in a background thread.

        Raises:
            TensorforceError: If a background thread is already registered.
        """
        if self.thread is not None:
            raise TensorforceError(message="Invalid start_reset.")
        self.thread = Thread(target=self.finish_reset)
        self.thread.start()

    def finish_reset(self):
        """
        Thread target of `start_reset()`: stores `(reset(), None, None)` as pending observation
        and then clears the thread reference.
        """
        # The observation is stored before the thread reference is cleared, so
        # retrieve_execute() finds it as soon as it sees `self.thread is None`.
        self.observation = (self.reset(), None, None)
        self.thread = None

    def start_execute(self, actions):
        """
        Starts `execute(actions=actions)` in a background thread.

        Args:
            actions: Action(s) passed on to `execute()`.

        Raises:
            TensorforceError: If an observation is still pending or a background thread is
                already registered.
        """
        if self.observation is not None or self.thread is not None:
            raise TensorforceError(message="Invalid start_execute.")
        self.thread = Thread(target=self.finish_execute, kwargs=dict(actions=actions))
        self.thread.start()

    def finish_execute(self, actions):
        """
        Thread target of `start_execute()`: stores the result of `execute(actions=actions)` as
        pending observation and then clears the thread reference.
        """
        # Same ordering as in finish_reset(): observation first, thread reference second.
        self.observation = self.execute(actions=actions)
        self.thread = None

    def retrieve_execute(self):
        """
        Retrieves the pending observation without blocking.

        Returns:
            None while a background thread is still registered, otherwise the stored
            observation, which is cleared by this call.

        Raises:
            TensorforceError: If no thread is registered and no observation is pending.
        """
        if self.thread is not None:
            # Background reset/execute has not finished yet
            return None
        if self.observation is None:
            raise TensorforceError(message="Invalid retrieve_execute.")
        observation = self.observation
        self.observation = None
        return observation