blob: 787fdc3f9a722799c8531cd2d7e3a2009b0c9681 [file] [log] [blame]
Craig Citro15744b12015-03-02 13:34:32 -08001#!/usr/bin/env python
Joe Gregorio652898b2011-05-02 21:07:43 -04002# -*- coding: utf-8 -*-
3#
Craig Citro751b7fb2014-09-23 11:20:38 -07004# Copyright 2014 Google Inc. All Rights Reserved.
Joe Gregorio652898b2011-05-02 21:07:43 -04005#
6# Licensed under the Apache License, Version 2.0 (the "License");
7# you may not use this file except in compliance with the License.
8# You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing, software
13# distributed under the License is distributed on an "AS IS" BASIS,
14# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15# See the License for the specific language governing permissions and
16# limitations under the License.
17
"""Simple command-line sample for the Google Prediction API

Command-line application that trains on your input data. This sample does
the same thing as the Hello Prediction! example. You might want to run
the setup.sh script to load the sample data to Google Storage.

Usage:
  $ python prediction.py "bucket/object" "model_id" "project_id"

You can also get help on all the command-line flags the program understands
by running:

  $ python prediction.py --help

To get detailed log output run:

  $ python prediction.py --logging_level=DEBUG
"""
INADA Naokie8d87822014-08-20 15:25:24 +090036from __future__ import print_function
Joe Gregorio652898b2011-05-02 21:07:43 -040037
Joe Gregorio968a9582012-03-07 14:52:52 -050038__author__ = ('jcgregorio@google.com (Joe Gregorio), '
39 'marccohen@google.com (Marc Cohen)')
Joe Gregorio652898b2011-05-02 21:07:43 -040040
Joe Gregorioe8391152013-06-28 01:30:57 -040041import argparse
Joe Gregorio968a9582012-03-07 14:52:52 -050042import os
Joe Gregorio652898b2011-05-02 21:07:43 -040043import pprint
44import sys
Joe Gregorio968a9582012-03-07 14:52:52 -050045import time
Joe Gregorio652898b2011-05-02 21:07:43 -040046
Antoine Picard7ba3c3f2014-05-09 14:55:58 -070047from apiclient import discovery
48from apiclient import sample_tools
Joe Gregorioe8391152013-06-28 01:30:57 -040049from oauth2client import client
Joe Gregorio652898b2011-05-02 21:07:43 -040050
Joe Gregorio968a9582012-03-07 14:52:52 -050051
# Time to wait (in seconds) between successive checks of training status.
SLEEP_TIME = 10


# Declare command-line flags.
# add_help=False because this parser is passed as a parent parser
# (parents=[argparser]) in main(); a parent that defines its own -h/--help
# would conflict with the help option of the parser that inherits from it.
argparser = argparse.ArgumentParser(add_help=False)
argparser.add_argument(
    'object_name',
    help='Full Google Storage path of csv data (ex bucket/object)')
argparser.add_argument(
    'model_id',
    help='Model Id of your choosing to name trained model')
argparser.add_argument(
    'project_id',
    help='Project Id of your Google Cloud Project')
Joe Gregorioe8391152013-06-28 01:30:57 -040064
65
def print_header(line):
    """Print ``line`` framed above and below by a rule of '=' characters.

    A blank line is printed first, then a rule, the text itself, and a
    second rule. Each rule is exactly as long as ``line``.

    Args:
        line: Header text to display.
    """
    header_line = '=' * len(line)
    print('\n' + header_line)
    print(line)
    print(header_line)
Joe Gregorioe8391152013-06-28 01:30:57 -040073
74
def main(argv):
    """Run the Prediction API sample from start to finish.

    Lists up to ten existing models, submits a training request for
    ``flags.model_id`` using the data at ``flags.object_name``, polls the
    model until training is done, then analyzes it, runs two sample
    predictions and finally deletes the model. Every request is issued
    against ``flags.project_id``.

    Args:
        argv: Command-line argument list (``sys.argv`` when run as a
            script). It is handed to ``sample_tools.init`` together with
            the module-level ``argparser``.

    Raises:
        Exception: If the reported training status is neither ``'DONE'``
            nor ``'RUNNING'``.
    """
    # If you previously ran this app with an earlier version of the API
    # or if you change the list of scopes below, revoke your app's permission
    # here: https://accounts.google.com/IssuedAuthSubTokens
    # Then re-run the app to re-authorize it.
    service, flags = sample_tools.init(
        argv, 'prediction', 'v1.6', __doc__, __file__, parents=[argparser],
        scope=(
            'https://www.googleapis.com/auth/prediction',
            'https://www.googleapis.com/auth/devstorage.read_only'))

    try:
        # Get access to the Prediction API.
        papi = service.trainedmodels()

        # List models.
        print_header('Fetching list of first ten models')
        result = papi.list(maxResults=10, project=flags.project_id).execute()
        print('List results:')
        pprint.pprint(result)

        # Start training request on a data set.
        print_header('Submitting model training request')
        body = {'id': flags.model_id,
                'storageDataLocation': flags.object_name}
        start = papi.insert(body=body, project=flags.project_id).execute()
        print('Training results:')
        pprint.pprint(start)

        # Wait for the training to complete.
        print_header('Waiting for training to complete')
        while True:
            status = papi.get(
                id=flags.model_id, project=flags.project_id).execute()
            state = status['trainingStatus']
            print('Training state: ' + state)
            if state == 'DONE':
                break
            if state != 'RUNNING':
                # Any status other than DONE/RUNNING is treated as a failure.
                raise Exception('Training Error: ' + state)
            time.sleep(SLEEP_TIME)

        # Job has completed. This report runs after the polling loop; placed
        # inside the loop it could never be reached, because every branch
        # above leaves the iteration via break, raise or the next poll.
        print('Training completed:')
        pprint.pprint(status)

        # Describe model.
        print_header('Fetching model description')
        result = papi.analyze(
            id=flags.model_id, project=flags.project_id).execute()
        print('Analyze results:')
        pprint.pprint(result)

        # Make some predictions using the newly trained model.
        print_header('Making some predictions')
        for sample_text in ['mucho bueno', 'bonjour, mon cher ami']:
            body = {'input': {'csvInstance': [sample_text]}}
            result = papi.predict(
                body=body, id=flags.model_id,
                project=flags.project_id).execute()
            print('Prediction results for "%s"...' % sample_text)
            pprint.pprint(result)

        # Delete model. The response is not used, so it is not kept.
        print_header('Deleting model')
        papi.delete(id=flags.model_id, project=flags.project_id).execute()
        print('Model deleted.')

    except client.AccessTokenRefreshError:
        print('The credentials have been revoked or expired, please re-run '
              'the application to re-authorize.')
Joe Gregorio652898b2011-05-02 21:07:43 -0400145
Joe Gregorioe8391152013-06-28 01:30:57 -0400146
# Script entry point: run the sample with the process's command-line arguments.
if __name__ == '__main__':
    main(sys.argv)