blob: 12270191e1fd7afdc81cc1eb606adf29d73c148c [file] [log] [blame]
Joe Gregorio652898b2011-05-02 21:07:43 -04001#!/usr/bin/python2.4
2# -*- coding: utf-8 -*-
3#
4# Copyright (C) 2010 Google Inc.
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17
"""Simple command-line sample for the Google Prediction API.

Command-line application that trains a model on your input data and then
asks the trained model for a prediction. This sample does the same thing as
the Hello Prediction! example. You might want to run the setup.sh script to
load the sample data to Google Storage.

Usage:
  $ python prediction.py --object_name="bucket/object"

The --object_name flag is required.

You can also get help on all the command-line flags the program understands
by running:

  $ python prediction.py --help

To get detailed log output run:

  $ python prediction.py --object_name="bucket/object" --logging_level=DEBUG
"""

__author__ = 'jcgregorio@google.com (Joe Gregorio)'
38
Robert Kaplow49cd5f82011-08-02 13:50:04 -040039import apiclient.errors
Joe Gregorio652898b2011-05-02 21:07:43 -040040import gflags
41import httplib2
42import logging
43import pprint
44import sys
45
46from apiclient.discovery import build
47from oauth2client.file import Storage
Joe Gregorio7d791212011-05-16 21:58:52 -070048from oauth2client.client import AccessTokenRefreshError
Joe Gregorio652898b2011-05-02 21:07:43 -040049from oauth2client.client import OAuth2WebServerFlow
50from oauth2client.tools import run
51
# Process-wide gflags registry. The DEFINE_* calls further down register
# flags on it, and main() parses the command line into it.
FLAGS = gflags.FLAGS

# OAuth 2.0 settings for this sample. Despite its name, OAuth2WebServerFlow
# can also drive the flow for native applications
# <http://code.google.com/apis/accounts/docs/OAuth2.html#IA>.
# The client_id and client_secret are copied from the API Access tab of the
# Google APIs Console <http://code.google.com/apis/console>. When creating
# credentials for this application, choose an Application type of
# "Installed application".
_FLOW_SETTINGS = dict(
    client_id='433807057907.apps.googleusercontent.com',
    client_secret='jigtZpMApkRxncxikFpR+SFg',
    scope='https://www.googleapis.com/auth/prediction',
    user_agent='prediction-cmdline-sample/1.0',
)

# Flow object that main() runs only when no valid stored credentials exist.
FLOW = OAuth2WebServerFlow(**_FLOW_SETTINGS)
68
# Command-line flags. gflags makes defining options easy; run this program
# with '--help' to list every flag it understands.

# Each name must be an attribute of the logging module, because main() maps
# the chosen value to a level with getattr(logging, ...).
_LOGGING_LEVEL_NAMES = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']

gflags.DEFINE_enum('logging_level', 'ERROR', _LOGGING_LEVEL_NAMES,
                   'Set the level of logging detail.')

# Location of the CSV training data; the program refuses to run without it.
gflags.DEFINE_string(
    'object_name', None,
    'Full Google Storage path of csv data (ex bucket/object)')
gflags.MarkFlagAsRequired('object_name')
Joe Gregorio652898b2011-05-02 21:07:43 -040081
def main(argv):
  """Train a Prediction API model on FLAGS.object_name and query it once.

  Parses the command-line flags and obtains OAuth 2.0 credentials. When no
  valid credentials are stored in 'prediction.dat', it runs FLOW. It then
  starts a training job on the CSV data named by --object_name and polls
  until the training status can be fetched. Finally it requests a prediction
  for a fixed sample input and prints every API response.

  Args:
    argv: the full command line, including the program name in argv[0].

  Exits with status 1 when the command-line flags cannot be parsed.
  """
  import time

  # Let the gflags module process the command-line arguments.
  try:
    argv = FLAGS(argv)
  except gflags.FlagsError as e:
    # Real newlines here: the former '\\n' printed a literal backslash-n.
    print('%s\nUsage: %s ARGS\n%s' % (e, argv[0], FLAGS))
    sys.exit(1)

  # Set the logging according to the command-line flag.
  logging.getLogger().setLevel(getattr(logging, FLAGS.logging_level))

  # If the Credentials don't exist or are invalid, run through the native
  # client flow. The Storage object writes good Credentials back to the file
  # so later runs can reuse them.
  storage = Storage('prediction.dat')
  credentials = storage.get()
  if credentials is None or credentials.invalid:
    credentials = run(FLOW, storage)

  # Create an httplib2.Http object to handle our HTTP requests and authorize
  # it with our good Credentials.
  http = credentials.authorize(httplib2.Http())

  service = build("prediction", "v1.3", http=http)

  try:
    # Start training on the data set.
    train = service.training()
    body = {'id': FLAGS.object_name}
    start = train.insert(body=body).execute()

    print('Started training')
    pprint.pprint(start)

    # Wait for the training to complete. This sample relies on the get call
    # returning an HTTP error until the training job has completed.
    # NOTE(review): every HttpError (including permanent ones such as a bad
    # object name or missing permission) keeps this loop polling forever;
    # consider inspecting the error's HTTP status before retrying.
    while True:
      try:
        status = train.get(data=FLAGS.object_name).execute()
      except apiclient.errors.HttpError as error:
        # Training job not yet completed; keep the error visible at DEBUG.
        logging.debug('Training status request failed: %s', error)
        print('Waiting for training to complete.')
        time.sleep(10)
      else:
        # Job has completed.
        pprint.pprint(status)
        break

    print('Training is complete')

    # Now make a prediction using that training.
    body = {'input': {'csvInstance': ["mucho bueno"]}}
    prediction = train.predict(body=body, data=FLAGS.object_name).execute()
    print('The prediction is:')
    pprint.pprint(prediction)

  except AccessTokenRefreshError:
    # The trailing space matters: the two literals are concatenated.
    print("The credentials have been revoked or expired, please re-run "
          "the application to re-authorize")


if __name__ == '__main__':
  main(sys.argv)