tspace is a data pipeline framework for deep reinforcement learning with an IO interface, processing and configuration. The current code base implements an automotive use case. The goal of the system is to increase the energy efficiency (reward) of a battery electric vehicle (BEV) by modifying the parameters (action) of the powertrain controller (VCU), based on observations of the vehicle (state), e.g. speed, acceleration, electric motor current and voltage. The main features are:
- works in both training and inference mode, supporting
  - coordinated ETL and ML pipelines,
  - online and offline training,
  - local and distributed training;
- supports multiple models:
  - reinforcement learning models with DDPG,
  - recurrent models (RDPG) for time sequences of arbitrary length, and
  - offline reinforcement learning with “Implicit Diffusion Q-Learning” (IDQL);
- the data pipelines are compatible with both ETL and ML dataflows, with
  - support for multiple data sources (local CAN or remote cloud object storage),
  - stateful time sequence processing with a sequential model, and
  - support for NoSQL databases as well as local and cloud data storage.
The diagram shows the basic architecture of tspace.
It is the entry point of tspace. It orchestrates the whole ETL and ML workflow.

- It configures KvaserCAN, RemoteCAN, Cruncher, Agent, Model, Database and Pipeline.
- It manages the scheduling of the two primary threads in the first tier of the cascaded threading pools in `tspace.avatar.main` (a minimal sketch follows below).
- It selects either KvaserCAN or RemoteCAN as the vehicle interface for reading the observation and applying the action.
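For orientation, here is a minimal sketch of such a two-thread entry point. It is not the actual tspace code; `capture_loop`, `learn_loop` and the timing are placeholders.

```python
# Minimal sketch of a two-thread entry point (illustrative names, not the tspace API).
import threading
from concurrent.futures import ThreadPoolExecutor

exit_event = threading.Event()

def capture_loop():
    # Primary thread 1: read observations from the vehicle interface
    # and push them into the data pipeline.
    while not exit_event.wait(timeout=0.1):
        pass  # placeholder for the ETL step

def learn_loop():
    # Primary thread 2: consume observations, infer actions, update the model.
    while not exit_event.wait(timeout=0.1):
        pass  # placeholder for the ML step

with ThreadPoolExecutor(max_workers=2) as pool:
    pool.submit(capture_loop)
    pool.submit(learn_loop)
    exit_event.set()  # in tspace this is driven by the HMI control events
```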
It is implemented with `Kvaser`, which provides:

- a local interface for reading the observation (CAN messages of vehicle states) via Kvaser, using `udp_context` to get CAN messages as JSON data from a local UDP server. It then encodes the raw JSON data into a `pandas.DataFrame` for forwarding through the data pipeline to `Cruncher` (a sketch of this path follows below);
- a local interface for applying the action (flashing parameters) onto the vehicle ECU (VCU). Before sending the action, it decodes the action from the `pandas.DataFrame` into a packed string buffer and then sends it to the ECU by calling `send_float_array` from `VehicleInterface.consume`;
- handling of the control messages for the training HMI, which go through the same UDP port. They are used to modify the threading events that control the episodic training process with `VehicleInterface.hmi_control`.
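A sketch of that local path, assuming one JSON datagram per read and illustrative signal names (the actual `udp_context` helper and message schema in tspace may differ):

```python
# Receive one CAN-message datagram as JSON over UDP and re-shape it into a DataFrame.
# Address, message layout and column names are assumptions for illustration.
import json
import socket

import pandas as pd

with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
    sock.bind(("127.0.0.1", 8002))       # local UDP server address (assumed)
    payload, _ = sock.recvfrom(65536)    # blocks until a datagram arrives

signals = json.loads(payload)            # e.g. {"timestamp": ..., "velocity": [...], ...}
observation = pd.DataFrame(
    {
        "velocity": signals["velocity"],
        "current": signals["current"],
        "voltage": signals["voltage"],
    }
)
observation["timestamp"] = pd.to_datetime(signals["timestamp"], unit="ms")
# `observation` is then forwarded through the data pipeline to Cruncher.
```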
It provides a remote interface to the vehicle via the cloud object storage system fed by the onboard TBox. It is implemented with `Cloud`:

- It reads the observation (CAN messages of vehicle states) from the cloud object storage system through `RemoteCanClient.get_signals`. It then encodes the raw JSON data into a `pandas.DataFrame` and forwards it to `Cruncher` through the data pipeline.
- It sends the action (flashing parameters) to the vehicle ECU (VCU) in the shared `VehicleInterface.consume` by calling `RemoteCanClient.send_torque_map`, which decodes the action from the `pandas.DataFrame` into a raw JSON string.
- It selects the training HMI to get the vehicle and driver information as configuration: `Cloud.hmi_capture_from_udp` for a local UDP server, `Cloud.hmi_capture_from_rmq` for a remote RocketMQ server, and `Cloud.hmi_capture_from_dummy` for pure inference mode without training or updating models. It shares the same control logic `VehicleInterface.hmi_control` with KvaserCAN.
It is the main pivot of the data pipeline for pre-processing the observation and post-processing the action:

- `Cruncher.filter` receives the observation through the data pipeline from KvaserCAN or RemoteCAN. It pre-processes the input data into the timestamped quadruple $(timestamp, state, action, reward, state')$ (see the illustrative layout below) and gives it to the reinforcement learning agent `DPG`, subsequently its child `DDPG` or `RDPG`, for inferring an optimal action determined by its current policy. After getting the prediction of the agent, it encodes the prediction result into an action object and forwards it to `VehicleInterface.consume` to be flashed onto the VCU.
- It collects the critic loss, the actor loss, the total reward of each episode, the running reward, and the action at the end of the episode. It also saves the model checkpoint and the training log locally.
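Purely for illustration, the timestamped quadruple could look roughly like this (column names and values are made up, not tspace's actual schema):

```python
import pandas as pd

# Dummy observation/action frames just to show the shape of the quadruple.
state = pd.DataFrame({"velocity": [10.2, 10.5], "current": [35.0, 36.1], "voltage": [620.0, 619.5]})
next_state = pd.DataFrame({"velocity": [10.8, 11.0], "current": [34.2, 33.9], "voltage": [619.0, 618.7]})
action = pd.DataFrame({"pedal_map_delta": [0.01, -0.02, 0.00]})
reward = -1.27  # e.g. negative electrical energy consumed during the step

quadruple = {
    "timestamp": pd.Timestamp.utcnow(),
    "state": state,        # observation before the action
    "action": action,      # parameter modification flashed to the VCU
    "reward": reward,
    "state_": next_state,  # observation after the action
}
```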
It provides a wrapper for the reinforcement learning model with `DPG`:

- It has an interface to data storage:
  - retrieves the observation meta information and database configuration from `Avatar`,
  - initializes the repo interface `Buffer`, subsequently `MongoBuffer` or `DaskBuffer`, which then initializes the database connection with `MongoPool` or `DaskPool` respectively.
- It transfers observation data to the neural network:
  - initializes the episode states,
  - defines the abstract methods `DPG.actor_predict`, `DPG.train`, `DPG.get_losses`, `DPG.soft_update_target`, `DPG.init_checkpoint`, `DPG.save_ckpt`, `DPG.touch_gpu` for concrete implementations in the child classes `DDPG` and `RDPG`,
  - provides the concrete methods `DPG.start_episode`, `DPG.end_episode`, `DPG.deposit`, `DPG.deposit_episode` (a rough sketch of this abstract/concrete split follows below),
  - `DPG.touch_gpu` is used to warm up the GPU before starting inference.
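The abstract/concrete split described above can be pictured roughly as follows; the class and signatures are simplified stand-ins, not the actual `DPG` definition:

```python
# Simplified stand-in for the DPG interface (not the actual tspace signatures).
from abc import ABC, abstractmethod

class DPGSketch(ABC):
    @abstractmethod
    def actor_predict(self, state):
        ...  # implemented by DDPG / RDPG

    @abstractmethod
    def train(self):
        ...  # implemented by DDPG / RDPG

    @abstractmethod
    def touch_gpu(self):
        ...  # warm up the GPU before inference

    def start_episode(self, timestamp):
        # Concrete: reset the per-episode state.
        self.episode_start = timestamp

    def deposit(self, record):
        # Concrete: hand one observation record to the Buffer.
        self.buffer.store(record)
```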
- It provides methods to create, load or initialize the Deep Deterministic Policy Gradient model, or restore checkpoints to it. It also exports the TFLite model.
- It provides the concrete methods for the abstract ones in the `DPG` interface. `DDPG.infer_single_sample` is the inference method with graph optimization via `tf.function`. `DDPG.sample_minibatch` provides a minibatch sampled from the buffer; it handles the bootstrap when the buffer is empty, i.e. when the `Buffer` has no samples because the first episode has not yet ended. `DDPG.update_with_batch` performs the backpropagation and applies the weight updates to the actor and critic networks during `DDPG.train` (a sketch of such a graph-compiled update step follows below).
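As a rough picture of what a graph-compiled update step looks like (this is a generic DDPG-style critic update, not the tspace implementation):

```python
import tensorflow as tf

@tf.function  # compiled to a graph, as for DDPG.infer_single_sample / DDPG.update_with_batch
def update_critic(critic, optimizer, states, actions, target_q):
    # One gradient step on the critic given a minibatch and its TD targets.
    with tf.GradientTape() as tape:
        q = critic([states, actions], training=True)
        loss = tf.reduce_mean(tf.square(target_q - q))
    grads = tape.gradient(loss, critic.trainable_variables)
    optimizer.apply_gradients(zip(grads, critic.trainable_variables))
    return loss
```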
- It provides methods to create, load or initialize the Recurrent Deterministic Policy Gradient model, or restore checkpoints to it.
- It provides the concrete methods for the abstract ones in the `DPG` interface. `RDPG.actor_predict_step` is the inference method with graph optimization via `tf.function`. `RDPG.train_step` is the training method with graph optimization via `tf.function`; it also applies the weight updates to the actor and critic networks. `RDPG.train` samples a ragged minibatch of episodes with different lengths from the buffer. It can handle training on time sequences of arbitrary length by truncated backpropagation through time (TBPTT), splitting the episodes and looping over the subsequences with Masking layers to update the weights via `RDPG.train_step` (see the sketch below).
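Conceptually, TBPTT over one episode amounts to something like the following sketch; `train_step` stands in for `RDPG.train_step` and the chunk length is an arbitrary illustration:

```python
import numpy as np

def tbptt(train_step, episode, chunk_len=64):
    # Split a (T, features) episode into fixed-length chunks and update after each one.
    # The stateful LSTM carries its hidden state across chunks, so the forward state
    # flows through the whole episode while gradients are truncated at chunk borders.
    n_chunks = int(np.ceil(len(episode) / chunk_len))
    for i in range(n_chunks):
        chunk = episode[i * chunk_len : (i + 1) * chunk_len]
        train_step(chunk)
```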
- It provides methods to create and initialize the Implicit Diffusion Q-Learning model.
- The implementation of the model is based on the repo jaxrl5 with a Jax and Flax interface.
- It provides the concrete methods for the abstract ones in the `DPG` interface. `IDQL.actor_predict` is the inference method. `IDQL.train` is the training method; jaxrl5 takes care of the weight updates to the actor, critic and value networks. It samples a minibatch of tuples (state, action, reward, next state) from the buffer.
It's the neural network model for the reinforcement learning agent. For now it is only implemented for `RDPG` in `SeqActor` and `SeqCritic`.
It is the actor network with two recurrent LSTM layers, two dense layers and a Masking layer for handling ragged input sequences (a rough Keras layout follows below).

- `SeqActor.predict` outputs the action given the state for inference, thus the batch dimension has to be one. `SeqActor.evaluate_actions` outputs the actions given a batch of states for training; it is used in the training loop to get the prediction of the target actor network for calculating the critic loss.
- It handles the ragged input sequences with the Masking layer and the stateful recurrent layers for TBPTT.
- For inference, `SeqCritic` is not used and only `SeqActor` is required.
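The layer stack described above corresponds roughly to the following Keras model; the hidden sizes are placeholders and statefulness/target-network handling is omitted:

```python
import tensorflow as tf

def make_seq_actor(state_dim: int, action_dim: int, hidden: int = 64) -> tf.keras.Model:
    # Masking -> two LSTM layers -> two Dense layers, mirroring the description above.
    inputs = tf.keras.Input(shape=(None, state_dim))           # padded/ragged time dimension
    x = tf.keras.layers.Masking(mask_value=0.0)(inputs)        # ignore padded steps
    x = tf.keras.layers.LSTM(hidden, return_sequences=True)(x)
    x = tf.keras.layers.LSTM(hidden, return_sequences=True)(x)
    x = tf.keras.layers.Dense(hidden, activation="relu")(x)
    outputs = tf.keras.layers.Dense(action_dim, activation="tanh")(x)
    return tf.keras.Model(inputs, outputs)
```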
It is the critic network with two recurrent LSTM layers, two dense layers and a Masking layer for handling ragged input sequences. `SeqCritic.evaluate_q` gives the Q-value for a batch of states and actions. It is used in the training loop `RDPG.train_step` to calculate the critic and actor losses.
represents the data storage in the repository pattern with two polymorphic abstraction layers, `Buffer` and `Pool`.
is an abstract class. It provides a view of the data storage to the agent (a minimal in-memory stand-in follows below):

- The agent uses the abstract methods `Buffer.load`, `Buffer.save` and `Buffer.close` to load or save data from or to the `Pool`, and to close the connection to the `Pool`.
- The abstract `Buffer.sample` samples a minibatch from the `Pool`. It requires the children of `Buffer` to implement a concrete, efficient sampling method, which depends on the underlying data storage system.
- The concrete method `Buffer.store` stores the whole episode data into the `Pool`.
- The concrete method `Buffer.find` simply calls `Pool.find` to find the data matching a given query.
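A minimal in-memory stand-in that mirrors this surface (not the tspace classes, just the shape of the contract):

```python
import random
from abc import ABC, abstractmethod

class BufferSketch(ABC):
    # Stand-in for the Buffer contract: sampling is abstract, store/find are concrete.
    def __init__(self):
        self.pool = []                       # stand-in for the real Pool

    @abstractmethod
    def sample(self, size):
        ...                                  # storage-specific efficient sampling

    def store(self, episode):
        self.pool.append(episode)            # persist a whole episode

    def find(self, predicate):
        return [e for e in self.pool if predicate(e)]

class ListBuffer(BufferSketch):
    def sample(self, size):
        return random.sample(self.pool, min(size, len(self.pool)))
```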
It's a concrete class for the underlying NoSQL database MongoDB.

- It implements the abstract methods required by the `Buffer` interface. `MongoBuffer.decode_batch_records` prepares the sample batch data from `MongoPool` into a compliant format for agent training.
- It can handle both the DDPG record data type and the RDPG episode data type.
It's a concrete class for the distributed data storage system Dask.

- It implements the abstract methods required by the `Buffer` interface. `DaskBuffer.decode_batch_records` prepares the sample batch data from `DaskPool` into a compliant format for agent training.
- It can handle both the DDPG record data type and the RDPG episode data type.
is an abstract class. It's the interface to the underlying data storage. For the moment, it is implemented by `MongoPool` and `DaskPool`.

- It defines the abstract methods `Pool.load`, `Pool.close`, `Pool.store`, `Pool.delete`, `Pool.find`, `Pool.sample` and `Pool._count` for the concrete classes to implement.
- It defines `PoolQuery` as the query object for the `Pool.sample`, `Pool.find` and `Pool._count` methods.
- It implements the iterable protocol with `Pool.__iter__` and `Pool.__getitem__` for the concrete classes to implement an efficient indexing method.
It's a concrete class for the underlying NoSQL database MongoDB with time series support. It handles both the record data type and the episode data type with MongoDB collection features.

- It provides the interface to the MongoDB database with the pymongo library (a small usage sketch follows below).
- It implements the abstract methods required by the `Pool` interface. `MongoPool.store_record` stores the record data into the MongoDB database for the `DDPG` agent. `MongoPool.store_episode` stores the episode data into the MongoDB database for the `RDPG` agent.
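For context, storing one record with pymongo looks roughly like this; the database, collection and field names are illustrative, not tspace's actual schema:

```python
from pymongo import MongoClient

client = MongoClient("mongodb://localhost:27017")   # connection string is an assumption
records = client["tspace_demo"]["records"]          # illustrative database/collection names

record = {
    "meta": {"vehicle": "truck-01", "driver": "driver-01"},
    "timestamp": 1717000000000,
    "state": {"velocity": [10.2, 10.5], "current": [35.0, 36.1]},
    "action": {"pedal_map_delta": [0.01, -0.02]},
    "reward": -1.27,
}
records.insert_one(record)
```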
It's an abstract class for the distributed data storage system Dask, since different backends are needed: Parquet for the record data type and Avro for the episode data type.

- It supports both local file storage and remote object storage with the Dask library.
- It defines the generic data type for the abstract methods required by the `Pool` interface. The generic data type can then be specialized by the concrete classes as either `dask.DataFrame` for the record data type or `dask.Bag` for the episode data type.
is a concrete class for the record data type with the Parquet file format as backend storage.

- It implements the abstract methods required by the `DaskPool` interface and subsequently `Pool`. `ParquetPool.sample` provides an efficient unified sampling interface via `dask.DataFrame` to a Parquet storage, either locally or remotely (see the sketch below). `ParquetPool.get_query` provides the query object through Dask indexing for the `ParquetPool.sample` method.
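A rough picture of such a unified sampling path with `dask.dataframe`; the path, column names and sampling fraction are illustrative:

```python
import dask.dataframe as dd

ddf = dd.read_parquet("data/records.parquet")     # local path or remote, e.g. "s3://bucket/records"
subset = ddf[ddf["vehicle"] == "truck-01"]        # query-style filtering on metadata columns
minibatch = subset.sample(frac=0.01).compute()    # materialize a small random sample as pandas
```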
is a concrete class for the episode data type with the Avro file format as backend storage.

- It implements the abstract methods required by the `DaskPool` interface and subsequently `Pool`. `AvroPool.sample` provides an efficient unified sampling interface via `dask.Bag` to an Avro storage, either locally or remotely. `AvroPool.get_query` provides the query object through Dask indexing for the `AvroPool.sample` method.
provides all classes for the configuration of the tspace framework. Most of them serve as meta information for the observation data and are used later in indexing or grouping for efficient sampling. It includes:

- `Truck` with children `TruckInCloud` and `TruckInField`, which have different interfaces through the mixins `TboxMixin` and `KvaserMixin`. It provides a managed truck list and two dictionaries for quick access to the truck configuration;
- `Driver` with properties to be stored in the meta information of the observation data;
- `TripMessenger` for the different HMI input sources;
- `CANMessenger` for the different CAN message sources;
- `DBConfig` for management of the database configuration.
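Illustrative stand-ins for such configuration records (plain dataclasses here; the actual tspace classes and fields differ):

```python
from dataclasses import dataclass

@dataclass
class TruckSketch:
    vid: str            # vehicle id used to index and group observations
    interface: str      # e.g. "kvaser" (field) or "tbox" (cloud)

@dataclass
class DriverSketch:
    pid: str            # driver id stored in the observation meta information

trucks = [TruckSketch("truck-01", "kvaser"), TruckSketch("truck-02", "tbox")]
trucks_by_id = {t.vid: t for t in trucks}   # quick access by vehicle id
```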
The scheduling of ETL and ML training and inference is carried out as two levels of cascaded threading pools.
is managed by Avatar with two primary threads in `tspace.avatar.main`:

- The first primary thread is for data capturing.
- The second primary thread is for training and inference.
calls `VehicleInterface.ignite`, which is shared by `Kvaser` and `Cloud`. It starts a secondary threading pool containing six threads:

- `VehicleInterface.produce` gets the raw data either from the local UDP server, as in `Kvaser`, or from the remote cloud object storage, as in `Cloud`, and forwards it to the raw data pipeline. In the case of `Kvaser`, it also gets the training HMI control messages from the same UDP server and puts them in the HMI data pipeline.
- `VehicleInterface.hmi_control` manages the episodic state machine that controls the training and inference process.
- `VehicleInterface.countdown` handles the episode end with a countdown timer so that the data capturing is aligned with the episode end event.
- `VehicleInterface.filter` transforms the raw input JSON object into a `pandas.DataFrame` and forwards it to the input data pipeline of the `Cruncher.filter` thread.
- `VehicleInterface.consume` is responsible for fetching the action object from the output data pipeline of the `Cruncher.filter` thread and having it flashed onto the vehicle ECU (VCU).
- `VehicleInterface.watch_dog` provides a watchdog to monitor the health of the data capturing process and the training process. It triggers a system stop if the observation or action quality falls below a threshold.
calls `Cruncher.filter`. Importantly, all processing in this thread is done synchronously in order to preserve the order of the time sequence, and thus the causality of observation and action.

- It gets the data through the input pipeline and delegates it to the agent for training or inference.
- After getting the prediction from the agent, it encodes the prediction result into an action object and forwards it through the output pipeline to `VehicleInterface.consume` to have it flashed onto the VCU.
- It also controls the training loop and the inference loop, and manages the training log and model checkpoint.
- This thread is synchronized with the threads in the secondary threading pool through the pre-defined `threading.Event` objects `start_event`, `stop_event`, `flash_event`, `interrupt_event` and `exit_event` (a small sketch of this event-driven coordination follows below).
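A small sketch of that event-driven coordination; only `flash_event` and `exit_event` are shown and the loop body is a placeholder:

```python
import threading

flash_event = threading.Event()   # set by Cruncher.filter when an action is ready
exit_event = threading.Event()    # set to shut the pipeline down

def consume_loop():
    # Stand-in for VehicleInterface.consume: wait for an action, "flash" it, repeat.
    while not exit_event.is_set():
        if flash_event.wait(timeout=0.1):
            flash_event.clear()   # placeholder for flashing the parameters onto the VCU
```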
- Add time sequence embedding database support with LanceDB for TimeGPT
- Batch mode for large scale inference and training with Unit of Work pattern
- Add schemes for serializing generic time series data
```sh
pip install tspace
```