Commit 4e301f5: register allocation eviction model training.

Yundi Qian committed Aug 3, 2021
1 parent 9bada5d
Showing 37 changed files with 18,532 additions and 6,141 deletions.
22 changes: 22 additions & 0 deletions README.md
@@ -23,6 +23,28 @@ Gradient.
For more details about MLGO, please refer to our paper
[MLGO: a Machine Learning Guided Compiler Optimizations Framework](https://arxiv.org/abs/2101.04808).

## Pretrained models

We occasionally release pretrained models that may be used as-is with LLVM.
Models are released as GitHub releases, and are named
[task]-[major-version].[minor-version]. The versions are semantic: the major
version corresponds to breaking changes on the LLVM/compiler side, and the
minor version corresponds to model updates that are independent of the
compiler.

When building LLVM, there is a flag `-DLLVM_INLINER_MODEL_PATH` which you may
set to the path of your inlining model. If the path is set to `download`,
CMake will download the most recent compatible model from GitHub. For example:

```sh
# The model is in /tmp/model, i.e. there is a file /tmp/model/saved_model.pb
# along with the rest of the TensorFlow saved_model files produced by training.
-DLLVM_INLINER_MODEL_PATH=/tmp/model

# Download the most recent compatible model
-DLLVM_INLINER_MODEL_PATH=download
```
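For context, a full configuration command might look like the sketch below; the
generator, build type, and source path are placeholders for your own setup, not
prescribed by this repository:

```sh
# Hypothetical invocation; adjust paths and options to your environment.
cmake -G Ninja \
    -DCMAKE_BUILD_TYPE=Release \
    -DLLVM_INLINER_MODEL_PATH=download \
    /path/to/llvm-project/llvm
```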

## Prerequisites

Currently, the assumptions for the system are:
60 changes: 60 additions & 0 deletions compiler_opt/rl/data_collector.py
@@ -16,6 +16,8 @@
"""Data collection module."""

import abc
import time

from typing import Iterator, Tuple, Dict
from tf_agents.trajectories import trajectory

@@ -48,3 +50,61 @@ def on_dataset_consumed(self,
Args:
dataset_iterator: the dataset_iterator that has been consumed.
"""


class EarlyExitChecker:
"""Class which checks if it is ok to early-exit from data collection."""

def __init__(self, deadline, thresholds, num_modules):
"""Initialize the early exit checker.
Args:
deadline: The deadline for data collection, in seconds.
thresholds: Early exit thresholds, e.g. [(0.8, 0.5)] means early exit is
allowable if 80% of data has been collected and we've waited 50% of the
maximum waiting time.
num_modules: How many total modules we are waiting for.
"""
self._deadline = deadline
self._thresholds = thresholds
self._num_modules = num_modules
self._start_time = time.time()
self._waited_time = 0

  def _should_exit(self, collected: int):
    """Checks whether we should exit early.

    If collected is negative, _should_exit will always return False.

    Args:
      collected: The amount of data we have collected.

    Returns:
      True if we should exit, otherwise False.
    """
if collected < 0:
return False

self._waited_time = round(time.time() - self._start_time)
for (data_threshold, deadline_threshold) in self._thresholds:
if ((collected >= self._num_modules * data_threshold) and
(self._waited_time >= self._deadline * deadline_threshold)):
return True
return False

def wait(self, get_num_finished_work):
"""Waits until the deadline has expired or an early exit is possible.
Args:
get_num_finished_work: a callable object which returns the amount of
finished work.
Returns:
The amount of time waited.
"""
while not self._should_exit(get_num_finished_work()):
time.sleep(1)
return self.waited_time()

def waited_time(self):
return self._waited_time
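
For illustration, a minimal usage sketch of this class; the threshold values
and the `fake_counter` callable are hypothetical stand-ins, not part of this
commit:

```python
from compiler_opt.rl import data_collector

# Hypothetical setup: 100 modules, a 4-second deadline, and the example
# threshold from the docstring above.
checker = data_collector.EarlyExitChecker(
    deadline=4, thresholds=[(0.8, 0.5)], num_modules=100)

def fake_counter():
  # Stand-in for a callable reporting finished work, e.g. how many
  # compilation jobs have results ready.
  return 85

# 85/100 >= 80%, so wait() returns once 50% of the deadline (~2s) has
# elapsed, and reports the (rounded) time waited.
waited = checker.wait(fake_counter)
```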
52 changes: 52 additions & 0 deletions compiler_opt/rl/data_collector_test.py
@@ -0,0 +1,52 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for data_collector."""

from unittest import mock

from absl.testing import absltest

from compiler_opt.rl import data_collector


class DataCollectorTest(absltest.TestCase):

@mock.patch('time.time')
def test_early_exit(self, mock_time):
mock_time.return_value = 0
early_exit = data_collector.EarlyExitChecker(
deadline=10, thresholds=[(.9, 0), (.5, .5), (0, 1)], num_modules=10)

# We've waited no time, so have to hit 90% to early exit
self.assertFalse(early_exit._should_exit(0))
self.assertFalse(early_exit._should_exit(5))
self.assertTrue(early_exit._should_exit(9))
self.assertEqual(early_exit.waited_time(), 0)

# We've waited 50% of the time, so only need to hit 50% to exit
mock_time.return_value = 5
self.assertFalse(early_exit._should_exit(0))
self.assertTrue(early_exit._should_exit(5))
self.assertEqual(early_exit.waited_time(), 5)

# We've waited 100% of the time, exit no matter what
mock_time.return_value = 10
self.assertTrue(early_exit._should_exit(0))
self.assertEqual(early_exit.waited_time(), 10)


if __name__ == '__main__':
absltest.main()
5 changes: 5 additions & 0 deletions compiler_opt/rl/feature_ops.py
@@ -64,6 +64,11 @@ def discard_feature(obs):
quantile, mean, std = quantile_map[obs_spec.name]

def normalization(obs):
# TODO(yundi): a temporary hard-coded solution for making test pass and
# submit regalloc code. Will have a big refactor in follow-up cls soon.
if obs_spec.name == 'progress':
obs = expand_dims_op(obs)
obs = tf.tile(obs, [1, 33])
expanded_obs = expand_dims_op(obs)
x = tf.cast(
tf.raw_ops.Bucketize(input=expanded_obs, boundaries=quantile),
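To make the hard-coded workaround above concrete, here is a small sketch of
what the expand-and-tile does to a batch of scalar `progress` observations
(the batch size and the use of the last axis are assumptions for
illustration):

```python
import tensorflow as tf

# A batch of 2 scalar 'progress' observations, shape (2,).
progress = tf.constant([0.25, 0.75])

# Expand to shape (2, 1), then tile to shape (2, 33): one copy per
# candidate register, matching the 33-register feature layout.
tiled = tf.tile(tf.expand_dims(progress, axis=-1), [1, 33])
print(tiled.shape)  # (2, 33)
```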
30 changes: 14 additions & 16 deletions compiler_opt/rl/local_data_collector.py
@@ -17,7 +17,6 @@

import collections
import random
-import time
from typing import Callable, Iterator, List, Tuple, Iterable, Dict

from absl import logging
@@ -115,24 +114,23 @@ def collect_data(
     results = self._schedule_jobs(policy_path, sampled_file_paths)
 
     def wait_for_termination():
-      wait_seconds = 0
-      while True:
+      early_exit = data_collector.EarlyExitChecker(_DEADLINE_IN_SECONDS,
+                                                   _WAIT_TERMINATION,
+                                                   self._num_modules)
+
+      def get_num_finished_work():
         finished_work = sum(res.ready() for res in results)
-        unfinised_work = len(results) - finished_work
+        unfinished_work = len(results) - finished_work
         prev_unfinished_work = sum(
             not res.ready() for _, res in self._unfinished_work)
-        total_unfinished_work = unfinised_work + prev_unfinished_work
-        for (data_collection_threshold,
-             wait_time_threshold) in _WAIT_TERMINATION:
-          if ((finished_work >= self._num_modules * data_collection_threshold)
-              and (wait_seconds >= _DEADLINE_IN_SECONDS * wait_time_threshold)):
-            if total_unfinished_work >= self._max_unfinished_tasks:
-              self._overloaded_workers_handler(total_unfinished_work)
-            break
-        else:
-          return wait_seconds
-        wait_seconds += 1
-        time.sleep(1)
+
+        # Handle overworked workers
+        total_unfinished_work = unfinished_work + prev_unfinished_work
+        if total_unfinished_work >= self._max_unfinished_tasks:
+          self._overloaded_workers_handler(total_unfinished_work)
+          return -1
+        return finished_work
+
+      return early_exit.wait(get_num_finished_work)
 
     wait_seconds = wait_for_termination()
1 change: 0 additions & 1 deletion compiler_opt/rl/local_data_collector_test.py
@@ -109,7 +109,6 @@ def parser(data_list):
overload_handler=overload_handler.handler)

collector.collect_data(policy_path='policy')
-    self.assertLen(overload_handler.counts, 0)
while [r for _, r in collector._unfinished_work if not r.ready()]:
time.sleep(1)

44 changes: 32 additions & 12 deletions compiler_opt/rl/regalloc/config.py
@@ -21,27 +21,47 @@
 from tf_agents.trajectories import time_step
 
 
+def get_num_registers():
+  return 33
+
+
+# pylint: disable=g-complex-comprehension
+@gin.configurable()
 def get_regalloc_signature_spec():
   """Returns (time_step_spec, action_spec) for LLVM register allocation."""
+  num_registers = get_num_registers()
+
   observation_spec = dict(
-      (key, tf.TensorSpec(dtype=tf.int64, shape=(), name=key))
-      for key in ('is_local_split', 'nr_defs_and_uses', 'nr_implicit_defs',
-                  'nr_identity_copies', 'liverange_size',
-                  'is_rematerializable'))
+      (key, tf.TensorSpec(dtype=tf.int64, shape=(num_registers), name=key))
+      for key in ('mask', 'is_hint', 'is_local', 'is_free'))
   observation_spec.update(
-      dict((key, tf.TensorSpec(dtype=tf.float32, shape=(), name=key))
-           for key in ('weighed_reads', 'weighed_writes', 'weighed_indvars',
-                       'hint_weights', 'start_bb_freq', 'end_bb_freq',
-                       'hottest_bb_freq', 'weighed_read_writes')))
+      dict((key,
+            tensor_spec.BoundedTensorSpec(
+                dtype=tf.int64,
+                shape=(num_registers),
+                name=key,
+                minimum=0,
+                maximum=6)) for key in ('max_stage', 'min_stage')))
+  observation_spec.update(
+      dict((key,
+            tf.TensorSpec(dtype=tf.float32, shape=(num_registers), name=key))
+           for key in ('weighed_reads_by_max', 'weighed_writes_by_max',
+                       'weighed_read_writes_by_max', 'weighed_indvars_by_max',
+                       'hint_weights_by_max', 'start_bb_freq_by_max',
+                       'end_bb_freq_by_max', 'hottest_bb_freq_by_max',
+                       'liverange_size', 'use_def_density', 'nr_defs_and_uses',
+                       'nr_broken_hints', 'nr_urgent', 'nr_rematerializable')))
+  observation_spec['progress'] = tensor_spec.BoundedTensorSpec(
+      dtype=tf.float32, shape=(), name='progress', minimum=0, maximum=1)
 
   reward_spec = tf.TensorSpec(dtype=tf.float32, shape=(), name='reward')
   time_step_spec = time_step.time_step_spec(observation_spec, reward_spec)
 
   action_spec = tensor_spec.BoundedTensorSpec(
-      dtype=tf.float32,
+      dtype=tf.int64,
       shape=(),
-      name='live_interval_weight',
-      minimum=-100,
-      maximum=20)
+      name='index_to_evict',
+      minimum=0,
+      maximum=num_registers - 1)
 
   return time_step_spec, action_spec
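
As a quick illustration of the new signature (a sketch, assuming the module is
importable as below): the action is now a discrete index into the 33 eviction
candidates rather than a continuous live-interval weight.

```python
from compiler_opt.rl.regalloc import config

time_step_spec, action_spec = config.get_regalloc_signature_spec()

# Per-register observation features have length 33, one per candidate...
print(time_step_spec.observation['mask'].shape)  # (33,)
# ...and the action selects which live interval to evict.
print(action_spec.minimum, action_spec.maximum)  # 0 32
```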
@@ -1,25 +1,25 @@
 import gin.tf.external_configurables
 import compiler_opt.rl.gin_external_configurables
 import compiler_opt.rl.regalloc.config
+import compiler_opt.rl.regalloc_network
 import tf_agents.agents.behavioral_cloning.behavioral_cloning_agent
-import tf_agents.networks.actor_distribution_network
 
 train_eval.get_signature_spec_fn=@config.get_regalloc_signature_spec
 train_eval.agent_name='behavioral_cloning'
-train_eval.num_iterations=200000
+train_eval.num_iterations=10000
 train_eval.batch_size=64
 train_eval.train_sequence_length=1
 
 get_observation_processing_layer_creator.quantile_file_dir='compiler_opt/rl/regalloc/vocab'
 get_observation_processing_layer_creator.with_z_score_normalization = False
 
-create_agent.policy_network = @actor_distribution_network.ActorDistributionNetwork
+create_agent.policy_network = @regalloc_network.RegAllocNetwork
 
-ActorDistributionNetwork.preprocessing_combiner=@tf.keras.layers.Concatenate()
-ActorDistributionNetwork.fc_layer_params=(80, 40)
-ActorDistributionNetwork.dropout_layer_params=(0.2, 0.2)
-ActorDistributionNetwork.activation_fn=@tf.keras.activations.relu
-NormalProjectionNetwork.mean_transform=None
+RegAllocNetwork.preprocessing_combiner=@tf.keras.layers.Concatenate()
+RegAllocNetwork.fc_layer_params=(80, 40)
+RegAllocNetwork.dropout_layer_params=(0.2, 0.2)
+RegAllocNetwork.activation_fn=@tf.keras.activations.relu
 
 tf.train.AdamOptimizer.learning_rate = 0.001
 tf.train.AdamOptimizer.epsilon = 0.0003125
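
For reference on how bindings like these take effect, here is a minimal sketch
using the gin API; the file path is a placeholder, since this page does not
show the gin file's name:

```python
import gin

# Placeholder path standing in for the regalloc gin file above.
gin.parse_config_file('/path/to/regalloc_behavioral_cloning.gin')

# Each binding becomes a default argument of the named configurable, e.g.:
print(gin.query_parameter('train_eval.num_iterations'))  # 10000
```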
