blob: 81f7a768847b45de2e3a700fca0166f77e2ddd6c [file] [log] [blame]
Joe Gregorio652898b2011-05-02 21:07:43 -04001#!/usr/bin/python2.4
2# -*- coding: utf-8 -*-
3#
4# Copyright (C) 2010 Google Inc.
5#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17
18"""Simple command-line sample for the Google Prediction API
19
20Command-line application that trains on some data. This sample does
21the same thing as the Hello Prediction! example.
22
23Usage:
24 $ python prediction.py
25
26You can also get help on all the command-line flags the program understands
27by running:
28
29 $ python prediction.py --help
30
31To get detailed log output run:
32
33 $ python prediction.py --logging_level=DEBUG
34"""
35
36__author__ = 'jcgregorio@google.com (Joe Gregorio)'
37
38import gflags
39import httplib2
40import logging
41import pprint
42import sys
43
44from apiclient.discovery import build
45from oauth2client.file import Storage
Joe Gregorio7d791212011-05-16 21:58:52 -070046from oauth2client.client import AccessTokenRefreshError
Joe Gregorio652898b2011-05-02 21:07:43 -040047from oauth2client.client import OAuth2WebServerFlow
48from oauth2client.tools import run
49
FLAGS = gflags.FLAGS

# OAuth 2.0 flow used when the user still has to authenticate.
# Although the class is called the "Web Server Flow", it can also handle
# the flow for native (installed) applications; see
# <http://code.google.com/apis/accounts/docs/OAuth2.html#IA>.
# The client_id and client_secret values are copied from the API Access
# tab of the Google APIs Console <http://code.google.com/apis/console>.
# When creating credentials for this application, be sure to choose an
# Application type of "Installed application".
FLOW = OAuth2WebServerFlow(
    client_id='433807057907.apps.googleusercontent.com',
    client_secret='jigtZpMApkRxncxikFpR+SFg',
    scope='https://www.googleapis.com/auth/prediction',
    user_agent='prediction-cmdline-sample/1.0')

# Command-line options are declared through gflags. Run this program with
# '--help' to list every flag it understands.
gflags.DEFINE_enum(
    'logging_level',
    'ERROR',
    ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
    'Set the level of logging detail.')
73
74
def main(argv):
    """Train a Prediction API model on sample data, then make one prediction.

    Parses the command-line flags, sets the root logging level, obtains
    OAuth 2.0 credentials (running the authorization flow when none are
    stored or the stored ones are invalid), starts training on a fixed
    Google Storage object, polls until the training status is no longer
    'RUNNING', and finally requests and prints a single prediction.

    Args:
      argv: list of strings, the full command line with the program name
        in argv[0]. It is handed to gflags for parsing.

    Exits the process with status 1 when gflags rejects the command line.
    If the access token cannot be refreshed, a message asking the user to
    re-run the application is printed instead of raising.
    """
    # Function-scope import, as in the original sample; only the polling
    # loop below needs it.
    import time

    # Let the gflags module process the command-line arguments.
    # sys.exc_info() is used instead of "except X, e" so this block parses
    # on both Python 2.4 and Python 3 without changing behavior.
    try:
        argv = FLAGS(argv)
    except gflags.FlagsError:
        flag_error = sys.exc_info()[1]
        # Real newlines here: the previous '\\n' printed a literal
        # backslash-n and left the usage text on one unreadable line.
        print('%s\nUsage: %s ARGS\n%s' % (flag_error, argv[0], FLAGS))
        sys.exit(1)

    # Set the logging level according to the command-line flag.
    logging.getLogger().setLevel(getattr(logging, FLAGS.logging_level))

    # If the credentials don't exist or are invalid, run through the native
    # client flow. The Storage object is handed to run() so that, if
    # successful, the good credentials get written back to the file.
    storage = Storage('prediction.dat')
    credentials = storage.get()
    if credentials is None or credentials.invalid:
        credentials = run(FLOW, storage)

    # Create an httplib2.Http object to handle the HTTP requests and
    # authorize it with the credentials.
    http = credentials.authorize(httplib2.Http())

    service = build("prediction", "v1.2", http=http)

    # Name of the Google Storage bucket/object holding the training data.
    object_name = "apiclient-prediction-sample/prediction_models/languages"

    try:
        # Start training on the data set.
        train = service.training()
        start = train.insert(data=object_name, body={}).execute()

        print('Started training')
        pprint.pprint(start)

        # Wait for the training to complete, polling every 10 seconds.
        # NOTE(review): any status other than 'RUNNING' ends the wait, so a
        # failure status (if the API reports one) is also announced as
        # "complete" -- confirm against the API's documented status values.
        while True:
            status = train.get(data=object_name).execute()
            pprint.pprint(status)
            if status['trainingStatus'] != 'RUNNING':
                break
            print('Waiting for training to complete.')
            time.sleep(10)
        print('Training is complete')

        # Now make a prediction using that training.
        body = {'input': {'csvInstance': ["mucho bueno"]}}
        prediction = service.predict(body=body, data=object_name).execute()
        print('The prediction is:')
        pprint.pprint(prediction)

    except AccessTokenRefreshError:
        # Trailing space after "re-run" keeps the two adjacent literals
        # from fusing into "re-runthe".
        print("The credentials have been revoked or expired, please re-run "
              "the application to re-authorize")
# Script entry point: run the sample only when this file is executed
# directly, passing the raw command line through to main().
if '__main__' == __name__:
    main(sys.argv)