blob: 50c1ebad445b3136f46d68bf861860448318c2bf [file] [log] [blame]
Joe Gregorio652898b2011-05-02 21:07:43 -04001#!/usr/bin/python2.4
2# -*- coding: utf-8 -*-
3#
Craig Citro751b7fb2014-09-23 11:20:38 -07004# Copyright 2014 Google Inc. All Rights Reserved.
Joe Gregorio652898b2011-05-02 21:07:43 -04005#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17
18"""Simple command-line sample for the Google Prediction API
19
Command-line application that trains a model on your input data, then
describes it, runs a couple of sample predictions against it, and deletes
it. This sample does the same thing as the Hello Prediction! example. You
might want to run the setup.sh script first to load the sample data into
Google Cloud Storage.
Joe Gregorio652898b2011-05-02 21:07:43 -040023
24Usage:
Antoine Picard7ba3c3f2014-05-09 14:55:58 -070025 $ python prediction.py "bucket/object" "model_id" "project_id"
Joe Gregorio652898b2011-05-02 21:07:43 -040026
27You can also get help on all the command-line flags the program understands
28by running:
29
30 $ python prediction.py --help
31
To get detailed log output, pass --logging_level along with the required
arguments:

  $ python prediction.py --logging_level=DEBUG "bucket/object" "model_id" "project_id"
35"""
36
# Sample authors in "email (Name)" form, separated by ", ".
__author__ = ', '.join([
    'jcgregorio@google.com (Joe Gregorio)',
    'marccohen@google.com (Marc Cohen)',
])
Joe Gregorio652898b2011-05-02 21:07:43 -040039
Joe Gregorioe8391152013-06-28 01:30:57 -040040import argparse
Joe Gregorio968a9582012-03-07 14:52:52 -050041import os
Joe Gregorio652898b2011-05-02 21:07:43 -040042import pprint
43import sys
Joe Gregorio968a9582012-03-07 14:52:52 -050044import time
Joe Gregorio652898b2011-05-02 21:07:43 -040045
Antoine Picard7ba3c3f2014-05-09 14:55:58 -070046from apiclient import discovery
47from apiclient import sample_tools
Joe Gregorioe8391152013-06-28 01:30:57 -040048from oauth2client import client
Joe Gregorio652898b2011-05-02 21:07:43 -040049
Joe Gregorio968a9582012-03-07 14:52:52 -050050
# Time to wait (in seconds) between successive checks of training status.
SLEEP_TIME = 10


# Declare command-line flags.
# add_help=False because this parser is handed to sample_tools.init as a
# parent parser (parents=[argparser]); argparse parent parsers that also
# define -h/--help would clash with the child parser's own help option.
argparser = argparse.ArgumentParser(add_help=False)
argparser.add_argument('object_name',
    help='Full Google Storage path of csv data (ex bucket/object)')
argparser.add_argument('model_id',
    help='Model Id of your choosing to name trained model')
# Previously this reused the model_id help text verbatim, so --help
# misdescribed the project argument.
argparser.add_argument('project_id',
    help='Id of your Google Cloud project (used for all Prediction API calls)')
Joe Gregorioe8391152013-06-28 01:30:57 -040063
64
def print_header(line):
  """Print `line` framed above and below by a rule of '=' of equal width.

  A blank line is emitted before the top rule so that successive sections
  of output are visually separated.
  """
  rule = '=' * len(line)
  for text in ('\n' + rule, line, rule):
    print(text)
72
73
def main(argv):
  """Run the Prediction API sample end to end.

  Lists up to ten existing models, then submits a training request for the
  CSV data named on the command line. It polls until training finishes,
  fetches the model analysis, makes two sample predictions and finally
  deletes the model.

  Args:
    argv: Full command line (including the program name). It is forwarded
      to sample_tools.init together with `argparser`, which yields the
      'prediction' v1.6 service and the parsed flags.

  Raises:
    Exception: if the model's trainingStatus is anything other than
      'RUNNING' or 'DONE' while polling.
  """
  # If you previously ran this app with an earlier version of the API
  # or if you change the list of scopes below, revoke your app's permission
  # here: https://accounts.google.com/IssuedAuthSubTokens
  # Then re-run the app to re-authorize it.
  service, flags = sample_tools.init(
      argv, 'prediction', 'v1.6', __doc__, __file__, parents=[argparser],
      scope=(
          'https://www.googleapis.com/auth/prediction',
          'https://www.googleapis.com/auth/devstorage.read_only'))

  try:
    # Get access to the Prediction API.
    papi = service.trainedmodels()

    # List models.
    print_header('Fetching list of first ten models')
    result = papi.list(maxResults=10, project=flags.project_id).execute()
    print('List results:')
    pprint.pprint(result)

    # Start training request on a data set.
    print_header('Submitting model training request')
    body = {'id': flags.model_id, 'storageDataLocation': flags.object_name}
    start = papi.insert(body=body, project=flags.project_id).execute()
    print('Training results:')
    pprint.pprint(start)

    # Wait for the training to complete, re-checking every SLEEP_TIME seconds.
    print_header('Waiting for training to complete')
    while True:
      status = papi.get(id=flags.model_id, project=flags.project_id).execute()
      state = status['trainingStatus']
      print('Training state: ' + state)
      if state == 'DONE':
        break
      if state != 'RUNNING':
        # Any state other than RUNNING/DONE is treated as a training failure.
        raise Exception('Training Error: ' + state)
      time.sleep(SLEEP_TIME)

    # Job has completed. This runs only after the loop exits on 'DONE';
    # `status` holds the final model resource returned by papi.get.
    print('Training completed:')
    pprint.pprint(status)

    # Describe model.
    print_header('Fetching model description')
    result = papi.analyze(id=flags.model_id, project=flags.project_id).execute()
    print('Analyze results:')
    pprint.pprint(result)

    # Make some predictions using the newly trained model.
    print_header('Making some predictions')
    for sample_text in ['mucho bueno', 'bonjour, mon cher ami']:
      body = {'input': {'csvInstance': [sample_text]}}
      result = papi.predict(
          body=body, id=flags.model_id, project=flags.project_id).execute()
      print('Prediction results for "%s"...' % sample_text)
      pprint.pprint(result)

    # Delete model.
    print_header('Deleting model')
    papi.delete(id=flags.model_id, project=flags.project_id).execute()
    print('Model deleted.')

  except client.AccessTokenRefreshError:
    print('The credentials have been revoked or expired, please re-run '
          'the application to re-authorize.')
Joe Gregorio652898b2011-05-02 21:07:43 -0400144
Joe Gregorioe8391152013-06-28 01:30:57 -0400145
if __name__ == '__main__':
  # Pass the complete argv (program name included); main forwards it to
  # sample_tools.init for flag parsing.
  main(argv=sys.argv)