blob: 958d79cdef698924211304ec9d5c46b1b00237ea [file] [log] [blame]
Joe Gregorio652898b2011-05-02 21:07:43 -04001#!/usr/bin/python2.4
2# -*- coding: utf-8 -*-
3#
4# Copyright (C) 2010 Google Inc.
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17
"""Simple command-line sample for the Google Prediction API.

Command-line application that trains on your input data. This sample does
the same thing as the Hello Prediction! example. You might want to run
the setup.sh script to load the sample data to Google Storage.
Joe Gregorio652898b2011-05-02 21:07:43 -040023
24Usage:
Joe Gregorio968a9582012-03-07 14:52:52 -050025 $ python prediction.py --object_name="bucket/object" --id="model_id"
Joe Gregorio652898b2011-05-02 21:07:43 -040026
27You can also get help on all the command-line flags the program understands
28by running:
29
30 $ python prediction.py --help
31
32To get detailed log output run:
33
34 $ python prediction.py --logging_level=DEBUG
35"""
36
# Sample authors, as a single comma-separated string.
__author__ = (
    'jcgregorio@google.com (Joe Gregorio), '
    'marccohen@google.com (Marc Cohen)'
)
Joe Gregorio652898b2011-05-02 21:07:43 -040039
Robert Kaplow49cd5f82011-08-02 13:50:04 -040040import apiclient.errors
Joe Gregorio652898b2011-05-02 21:07:43 -040041import gflags
42import httplib2
43import logging
Joe Gregorio968a9582012-03-07 14:52:52 -050044import os
Joe Gregorio652898b2011-05-02 21:07:43 -040045import pprint
46import sys
Joe Gregorio968a9582012-03-07 14:52:52 -050047import time
Joe Gregorio652898b2011-05-02 21:07:43 -040048
49from apiclient.discovery import build
50from oauth2client.file import Storage
Joe Gregorio7d791212011-05-16 21:58:52 -070051from oauth2client.client import AccessTokenRefreshError
Joe Gregorio968a9582012-03-07 14:52:52 -050052from oauth2client.client import flow_from_clientsecrets
Joe Gregorio652898b2011-05-02 21:07:43 -040053from oauth2client.tools import run
54
# Module-wide handle on the gflags registry; flag values are read from it
# once main() has parsed the command line.
FLAGS = gflags.FLAGS

# CLIENT_SECRETS is the name of the file holding this application's
# OAuth 2.0 information (client_id and client_secret). Those values are
# listed on the API Access tab of the Google APIs
# Console <http://code.google.com/apis/console>.
CLIENT_SECRETS = 'client_secrets.json'

# Message handed to flow_from_clientsecrets() for the case where the
# CLIENT_SECRETS file is missing.
# NOTE(review): the path reported here is built from this script's
# directory, while FLOW below is given the bare relative file name --
# confirm both point at the same file when run from another directory.
MISSING_CLIENT_SECRETS_MESSAGE = """
WARNING: Please configure OAuth 2.0

To make this sample run you will need to populate the client_secrets.json file
found at:

 %s

with information from the APIs Console <https://code.google.com/apis/console>.

""" % os.path.join(os.path.dirname(__file__), CLIENT_SECRETS)

# Flow object used by main() when it has to authenticate.
FLOW = flow_from_clientsecrets(
    CLIENT_SECRETS,
    scope='https://www.googleapis.com/auth/prediction',
    message=MISSING_CLIENT_SECRETS_MESSAGE)
Joe Gregorio652898b2011-05-02 21:07:43 -040081
# Command-line options are declared through the gflags module. Running
# this program with '--help' lists every flag it understands.
gflags.DEFINE_enum(
    'logging_level',
    'ERROR',
    ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
    'Set the level of logging detail.')

# Location of the training data; must be supplied by the user.
gflags.DEFINE_string(
    'object_name',
    None,
    'Full Google Storage path of csv data (ex bucket/object)')
gflags.MarkFlagAsRequired('object_name')

# Name under which the trained model is stored; must be supplied by the user.
gflags.DEFINE_string(
    'id',
    None,
    'Model Id of your choosing to name trained model')
gflags.MarkFlagAsRequired('id')

# Seconds to pause between two consecutive training-status checks.
SLEEP_TIME = 10
101
def print_header(line):
    """Print `line` framed above and below by a row of '=' characters.

    Each row is exactly as long as `line`, and a blank line is emitted
    before the top row.

    Args:
      line: The heading text to display.
    """
    rule = '=' * len(line)
    for row in ('\n' + rule, line, rule):
        print(row)
109
def _wait_for_training(papi, model_id):
    """Poll a model's training status until it reports 'DONE'.

    Pauses SLEEP_TIME seconds between polls while the status is 'RUNNING'
    and prints the status seen on every poll.

    Args:
      papi: The trainedmodels() resource of the Prediction API service.
      model_id: Id of the model whose training is being awaited.

    Returns:
      The response of the status request that reported 'DONE'.

    Raises:
      Exception: If the status is anything other than 'DONE' or 'RUNNING'.
    """
    while True:
        status = papi.get(id=model_id).execute()
        state = status['trainingStatus']
        print('Training state: ' + state)
        if state == 'DONE':
            return status
        if state != 'RUNNING':
            raise Exception('Training Error: ' + state)
        time.sleep(SLEEP_TIME)


def main(argv):
    """Run the sample: list models, train, analyze, predict, then delete.

    Parses the command-line flags, loads stored OAuth 2.0 credentials (or
    runs the native client flow when they are missing or invalid) and then
    issues the requests against the trainedmodels collection of the
    Prediction API v1.5, printing each response.

    If flag parsing fails, a usage message is printed and the process
    exits with status 1. If the access token cannot be refreshed, a hint
    to re-run the application is printed.

    Args:
      argv: The full argument vector, with the program name at index 0.

    Raises:
      Exception: If training ends in a state other than 'DONE'.
    """
    # Let the gflags module process the command-line arguments.
    try:
        argv = FLAGS(argv)
    except gflags.FlagsError:
        # Fetch the exception through sys.exc_info() so this handler parses
        # under both the Python 2 and the Python 3 except-clause grammar.
        error = sys.exc_info()[1]
        print('%s\nUsage: %s ARGS\n%s' % (error, argv[0], FLAGS))
        sys.exit(1)

    # Set the logging according to the command-line flag.
    logging.getLogger().setLevel(getattr(logging, FLAGS.logging_level))

    # If the Credentials don't exist or are invalid run through the native
    # client flow. The Storage object will ensure that if successful the
    # good Credentials will get written back to a file.
    storage = Storage('prediction.dat')
    credentials = storage.get()
    if credentials is None or credentials.invalid:
        credentials = run(FLOW, storage)

    # Create an httplib2.Http object to handle our HTTP requests and
    # authorize it with our good Credentials.
    http = httplib2.Http()
    http = credentials.authorize(http)

    try:
        # Get access to the Prediction API.
        service = build("prediction", "v1.5", http=http)
        papi = service.trainedmodels()

        # List models.
        print_header('Fetching list of first ten models')
        result = papi.list(maxResults=10).execute()
        print('List results:')
        pprint.pprint(result)

        # Start training request on a data set.
        print_header('Submitting model training request')
        body = {'id': FLAGS.id, 'storageDataLocation': FLAGS.object_name}
        start = papi.insert(body=body).execute()
        print('Training results:')
        pprint.pprint(start)

        # Wait for the training to complete, then report the final status.
        print_header('Waiting for training to complete')
        status = _wait_for_training(papi, FLAGS.id)
        print('Training completed:')
        pprint.pprint(status)

        # Describe model.
        print_header('Fetching model description')
        result = papi.analyze(id=FLAGS.id).execute()
        print('Analyze results:')
        pprint.pprint(result)

        # Make a prediction using the newly trained model.
        print_header('Making a prediction')
        body = {'input': {'csvInstance': ["mucho bueno"]}}
        result = papi.predict(body=body, id=FLAGS.id).execute()
        print('Prediction results...')
        pprint.pprint(result)

        # Delete model.
        print_header('Deleting model')
        papi.delete(id=FLAGS.id).execute()
        print('Model deleted.')

    except AccessTokenRefreshError:
        print("The credentials have been revoked or expired, please re-run "
              "the application to re-authorize")
Joe Gregorio652898b2011-05-02 21:07:43 -0400193
# Start the sample only when this file is executed as a script, passing
# main() the unmodified argument vector.
if __name__ == '__main__':
    main(sys.argv)