Skip to content
This repository has been archived by the owner on Sep 26, 2020. It is now read-only.

Commit

Permalink
Add upload training results
Browse files Browse the repository at this point in the history
  • Loading branch information
Octogonapus committed Jan 31, 2020
1 parent 5b9ae9d commit 35eb053
Showing 1 changed file with 26 additions and 54 deletions.
80 changes: 26 additions & 54 deletions axon/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,34 +603,6 @@ def impl_download_untrained_model(model_path, bucket_name, region):
print("Downloaded from: {}\n".format(key))


def impl_upload_trained_model(model_path, bucket_name, region):
"""
Uploads an trained model to S3.
:param model_path: The file path to the model to upload, ending with the name of the model.
:param bucket_name: The S3 bucket name.
:param region: The region, or `None` to pull the region from the environment.
"""
client = make_client("s3", region)
key = "axon-trained-models/" + os.path.basename(model_path)
client.upload_file(model_path, bucket_name, key)
print("Uploaded to: {}\n".format(key))


def impl_download_trained_model(model_path, bucket_name, region):
"""
Downloads an trained model from S3.
:param model_path: The file path to download to, ending with the name of the model.
:param bucket_name: The S3 bucket name.
:param region: The region, or `None` to pull the region from the environment.
"""
client = make_client("s3", region)
key = "axon-trained-models/" + os.path.basename(model_path)
client.download_file(bucket_name, key, model_path)
print("Downloaded from: {}\n".format(key))


def impl_download_training_script(script_path, bucket_name, region):
"""
Downloads a training script from S3.
Expand Down Expand Up @@ -723,6 +695,16 @@ def impl_remove_heartbeat(job_id, bucket_name, region):
print("Removed heartbeat file in: {}\n".format(remote_path))


def impl_upload_training_results(job_id, output_dir, bucket_name, region):
client = make_client("s3", region)
files_to_upload = [os.path.join(output_dir, it) for it in os.listdir(output_dir)]
files_to_upload = [it for it in files_to_upload if os.path.isfile(it)]
for elem in files_to_upload:
key = "axon-training-results/{}/{}".format(job_id, os.path.basename(elem))
client.upload_file(elem, bucket_name, key)
print("Uploaded to: {}\n".format(key))


def create_progress_prefix(job_id):
return "axon-training-progress/{}".format(job_id)

Expand Down Expand Up @@ -848,32 +830,6 @@ def download_untrained_model(model_path, region):
impl_download_untrained_model(model_path, ensure_s3_bucket(region), region)


@cli.command(name="upload-trained-model")
@click.argument("model-path")
@click.option("--region", help="The region to connect to.",
type=click.Choice(region_choices))
def upload_trained_model(model_path, region):
"""
Uploads a trained model from a local file.
MODEL_PATH The path to the model to upload, ending with the name of the model.
"""
impl_upload_trained_model(model_path, ensure_s3_bucket(region), region)


@cli.command(name="download-trained-model")
@click.argument("model-path")
@click.option("--region", help="The region to connect to.",
type=click.Choice(region_choices))
def download_trained_model(model_path, region):
"""
Downloads a trained model to a local file.
MODEL_PATH The path to download the model to, ending with the name of the model.
"""
impl_download_trained_model(model_path, ensure_s3_bucket(region), region)


@cli.command(name="download-training-script")
@click.argument("script-path")
@click.option("--region", help="The region to connect to.",
Expand Down Expand Up @@ -955,3 +911,19 @@ def remove_heartbeat(job_id, region):
JOB_ID The unique Job ID.
"""
impl_remove_heartbeat(job_id, ensure_s3_bucket(region), region)


@cli.command(name="upload-training-results")
@click.argument("job-id")
@click.argument("output-dir")
@click.option("--region", help="The region to connect to.",
type=click.Choice(region_choices))
def upload_training_results(job_id, output_dir, region):
"""
Uploads the results from running a training script.
JOB_ID The unique Job ID.
OUTPUT_DIR The directory containing the results.
"""
impl_upload_training_results(job_id, output_dir, ensure_s3_bucket(region), region)

0 comments on commit 35eb053

Please sign in to comment.