blob: 9daf31ec6beeb22078b3b35883a3f27fdffee44a [file] [log] [blame]
Joe Gregorio652898b2011-05-02 21:07:43 -04001#!/usr/bin/python2.4
2# -*- coding: utf-8 -*-
3#
4# Copyright (C) 2010 Google Inc.
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17
"""Simple command-line sample for the Google Prediction API.

Command-line application that trains on your input data. This sample does
the same thing as the Hello Prediction! example. You might want to run
the setup.sh script to load the sample data into Google Storage.

Usage:
  $ python prediction.py --object_name="bucket/object"

You can also get help on all the command-line flags the program understands
by running:

  $ python prediction.py --help

To get detailed log output run (--object_name is always required):

  $ python prediction.py --object_name="bucket/object" --logging_level=DEBUG
"""
36
# Original author / point of contact for this sample.
__author__ = "jcgregorio@google.com (Joe Gregorio)"
38
39import gflags
40import httplib2
41import logging
42import pprint
43import sys
44
45from apiclient.discovery import build
46from oauth2client.file import Storage
Joe Gregorio7d791212011-05-16 21:58:52 -070047from oauth2client.client import AccessTokenRefreshError
Joe Gregorio652898b2011-05-02 21:07:43 -040048from oauth2client.client import OAuth2WebServerFlow
49from oauth2client.tools import run
50
# Module-wide gflags registry; the command line is parsed into it in main().
FLAGS = gflags.FLAGS

# OAuth 2.0 flow used whenever no valid stored credentials are available.
# Despite its name, OAuth2WebServerFlow also handles the flow for native
# (installed) applications:
#   <http://code.google.com/apis/accounts/docs/OAuth2.html#IA>
# The client_id and client_secret are copied from the API Access tab of the
# Google APIs Console <http://code.google.com/apis/console>. When creating
# credentials for this application, be sure to choose the Application type
# "Installed application".
FLOW = OAuth2WebServerFlow(
    client_id='433807057907.apps.googleusercontent.com',
    client_secret='jigtZpMApkRxncxikFpR+SFg',
    scope='https://www.googleapis.com/auth/prediction',
    user_agent='prediction-cmdline-sample/1.0')
67
# Command-line flags, declared with gflags. Run the program with '--help'
# to list every flag it understands.

# Name of the root logger's level; main() looks up the logging module
# constant with this same name.
gflags.DEFINE_enum(
    'logging_level', 'ERROR',
    ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
    'Set the level of logging detail.')

# Google Storage location of the CSV training data. It has no default and is
# marked required, so gflags rejects a command line that omits it.
gflags.DEFINE_string(
    'object_name', None,
    'Full Google Storage path of csv data (ex bucket/object)')
gflags.MarkFlagAsRequired('object_name')
Joe Gregorio652898b2011-05-02 21:07:43 -040080
def main(argv):
  """Train a Prediction API model on a Google Storage object, then query it.

  Steps, in order:
    1. Parse the command-line flags; on a flag error, print the error and
       usage text and exit with status 1.
    2. Load OAuth 2.0 credentials from 'prediction.dat', running FLOW when
       they are missing or invalid.
    3. Start training on --object_name, then poll every 10 seconds until the
       reported 'trainingStatus' is anything other than 'RUNNING'.
    4. Request a prediction for the sample CSV instance "mucho bueno" and
       print the response.

  If an AccessTokenRefreshError is raised during steps 3-4, a message asking
  the user to re-run the application is printed instead.

  Args:
    argv: the full command line, program name first (e.g. sys.argv).
  """
  import time  # Only needed for the polling delay below.

  # Let the gflags module process the command-line arguments.
  try:
    argv = FLAGS(argv)
  except gflags.FlagsError:
    # Fetch the exception via sys.exc_info() so this parses on Python 2.4
    # (per the shebang; no "except E as e") as well as Python 3 (no
    # "except E, e").
    e = sys.exc_info()[1]
    # Real newlines here: the original '\\n' printed a literal backslash-n.
    print('%s\nUsage: %s ARGS\n%s' % (e, argv[0], FLAGS))
    sys.exit(1)

  # Set the logging according to the command-line flag.
  logging.getLogger().setLevel(getattr(logging, FLAGS.logging_level))

  # If the Credentials don't exist or are invalid, run through the native
  # client flow. The Storage object will ensure that, if successful, the good
  # Credentials get written back to the file.
  storage = Storage('prediction.dat')
  credentials = storage.get()
  if credentials is None or credentials.invalid:
    credentials = run(FLOW, storage)

  # Create an httplib2.Http object to handle our HTTP requests and authorize
  # it with our good Credentials.
  http = credentials.authorize(httplib2.Http())

  service = build("prediction", "v1.2", http=http)

  try:
    # Start training on the data set named by --object_name.
    train = service.training()
    body = {'id': FLAGS.object_name}
    start = train.insert(body=body).execute()

    print('Started training')
    pprint.pprint(start)

    # Wait for the training to complete (i.e. leave the 'RUNNING' state).
    while True:
      status = train.get(data=FLAGS.object_name).execute()
      pprint.pprint(status)
      if status['trainingStatus'] != 'RUNNING':
        break
      print('Waiting for training to complete.')
      time.sleep(10)
    print('Training is complete')

    # Now make a prediction using that training.
    body = {'input': {'csvInstance': ["mucho bueno"]}}
    prediction = service.predict(body=body, data=FLAGS.object_name).execute()
    print('The prediction is:')
    pprint.pprint(prediction)

  except AccessTokenRefreshError:
    # Adjacent literals are joined verbatim; the trailing space after
    # "re-run" fixes the original's "re-runthe" output.
    print('The credentials have been revoked or expired, please re-run '
          'the application to re-authorize')
Joe Gregorio652898b2011-05-02 21:07:43 -0400138
if __name__ == "__main__":
  # Script entry point: hand main() the full command line, program name
  # included.
  main(sys.argv)