blob: a130f2a29b14733964fd6341bbd9b7f5d1842ca6 [file] [log] [blame]
Joe Gregorio652898b2011-05-02 21:07:43 -04001#!/usr/bin/python2.4
2# -*- coding: utf-8 -*-
3#
4# Copyright (C) 2010 Google Inc.
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17
18"""Simple command-line sample for the Google Prediction API
19
Joe Gregorio65826f92011-06-03 11:20:29 -040020Command-line application that trains on your input data. This sample does
21the same thing as the Hello Prediction! example. You might want to run
22the setup.sh script to load the sample data to Google Storage.
Joe Gregorio652898b2011-05-02 21:07:43 -040023
24Usage:
  $ python prediction.py "bucket/object" "model_id"
Joe Gregorio652898b2011-05-02 21:07:43 -040026
27You can also get help on all the command-line flags the program understands
28by running:
29
30 $ python prediction.py --help
31
32To get detailed log output run:
33
34 $ python prediction.py --logging_level=DEBUG
35"""
36
Joe Gregorio968a9582012-03-07 14:52:52 -050037__author__ = ('jcgregorio@google.com (Joe Gregorio), '
38 'marccohen@google.com (Marc Cohen)')
Joe Gregorio652898b2011-05-02 21:07:43 -040039
Joe Gregorioe8391152013-06-28 01:30:57 -040040import argparse
Joe Gregorio968a9582012-03-07 14:52:52 -050041import os
Joe Gregorio652898b2011-05-02 21:07:43 -040042import pprint
43import sys
Joe Gregorio968a9582012-03-07 14:52:52 -050044import time
Joe Gregorio652898b2011-05-02 21:07:43 -040045
John Asmuth864311d2014-04-24 15:46:08 -040046from googleapiclient import discovery
47from googleapiclient import sample_tools
Joe Gregorioe8391152013-06-28 01:30:57 -040048from oauth2client import client
Joe Gregorio652898b2011-05-02 21:07:43 -040049
Joe Gregorio968a9582012-03-07 14:52:52 -050050
# Polling interval, in seconds, between successive training-status checks.
SLEEP_TIME = 10


# Positional command-line arguments for this sample. The parser is built
# without its own help option (add_help=False) because it is passed to
# sample_tools.init() as a parent parser rather than used directly.
argparser = argparse.ArgumentParser(add_help=False)
argparser.add_argument(
    'object_name',
    help='Full Google Storage path of csv data (ex bucket/object)',
)
argparser.add_argument(
    'id',
    help='Model Id of your choosing to name trained model',
)
61
62
def print_header(line):
    """Print `line` framed above and below by a rule of '=' characters.

    The rule is exactly as long as `line`, and a blank line is emitted before
    the top rule so consecutive sections are visually separated.

    Args:
      line: the header text to print.
    """
    header_line = '=' * len(line)
    # Single-argument print(...) behaves the same as a Python 2 print
    # statement and as the Python 3 print function.
    print('\n' + header_line)
    print(line)
    print(header_line)
70
71
def _wait_for_training(papi, model_id):
    """Poll a model's training status until it reports 'DONE'.

    Args:
      papi: the `trainedmodels()` resource obtained from the service object.
      model_id: id of the model whose training is being awaited.

    Returns:
      The status response whose 'trainingStatus' is 'DONE'.

    Raises:
      Exception: if the reported state is neither 'DONE' nor 'RUNNING'.
    """
    while True:
        status = papi.get(id=model_id).execute()
        state = status['trainingStatus']
        print('Training state: ' + state)
        if state == 'DONE':
            return status
        if state != 'RUNNING':
            raise Exception('Training Error: ' + state)
        # Still training: pause before asking again.
        time.sleep(SLEEP_TIME)


def main(argv):
    """Run the sample: list, train, analyze, predict with and delete a model.

    The steps, in order: list up to ten existing models, submit a training
    request for the model named by the `id` argument using the data at
    `object_name`, wait for training to finish, fetch the model analysis,
    make one prediction, and finally delete the model.

    Args:
      argv: the full command-line argument list (including the program
        name), as passed on to sample_tools.init().

    Raises:
      Exception: if training ends in a state other than 'DONE'.
    """
    service, flags = sample_tools.init(
        argv, 'prediction', 'v1.5', __doc__, __file__, parents=[argparser],
        scope='https://www.googleapis.com/auth/prediction')

    try:
        # Get access to the Prediction API.
        papi = service.trainedmodels()

        # List models.
        print_header('Fetching list of first ten models')
        result = papi.list(maxResults=10).execute()
        print('List results:')
        pprint.pprint(result)

        # Start training request on a data set.
        print_header('Submitting model training request')
        body = {'id': flags.id, 'storageDataLocation': flags.object_name}
        start = papi.insert(body=body).execute()
        print('Training results:')
        pprint.pprint(start)

        # Wait for the training to complete.
        print_header('Waiting for training to complete')
        status = _wait_for_training(papi, flags.id)

        # Job has completed. (This report used to sit inside the polling
        # loop after branches that always left it, so it never ran.)
        print('Training completed:')
        pprint.pprint(status)

        # Describe model.
        print_header('Fetching model description')
        result = papi.analyze(id=flags.id).execute()
        print('Analyze results:')
        pprint.pprint(result)

        # Make a prediction using the newly trained model.
        print_header('Making a prediction')
        body = {'input': {'csvInstance': ["mucho bueno"]}}
        result = papi.predict(body=body, id=flags.id).execute()
        print('Prediction results...')
        pprint.pprint(result)

        # Delete model.
        print_header('Deleting model')
        papi.delete(id=flags.id).execute()
        print('Model deleted.')

    except client.AccessTokenRefreshError:
        # Note the trailing space after "re-run": without it the two
        # adjacent literals join as "re-runthe".
        print('The credentials have been revoked or expired, please re-run '
              'the application to re-authorize')
Joe Gregorio652898b2011-05-02 21:07:43 -0400134
Joe Gregorioe8391152013-06-28 01:30:57 -0400135
# Script entry point: run the sample only when executed directly, handing
# main() the full argument vector (program name included).
if __name__ == '__main__':
    main(sys.argv)